diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 5d89ba94ad..2ee343d8a4 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -67,7 +67,9 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listen_http(parse_listener_def(0, config))
+ hs = self.hs
+ assert isinstance(hs, GenericWorkerServer)
+ hs._listen_http(parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
@@ -115,7 +117,9 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listener_http(self.hs.config, parse_listener_def(0, config))
+ hs = self.hs
+ assert isinstance(hs, SynapseHomeServer)
+ hs._listener_http(self.hs.config, parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index febcc1499d..e2a3bad065 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
+from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock
from typing_extensions import TypeAlias
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.appservice import (
ApplicationService,
@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
from ..utils import MockClock
-if TYPE_CHECKING:
- from twisted.internet.testing import MemoryReactor
-
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None:
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 56193bc000..d6db5b6423 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -194,7 +194,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
- time.time() * 1000,
+ int(time.time() * 1000),
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
)
self.get_success(r)
@@ -289,7 +289,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
- time.time() * 1000,
+ int(time.time() * 1000),
# None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_verify_keys.
@@ -468,9 +468,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
- res = key_json[lookup_triplet]
- self.assertEqual(len(res), 1)
- res = res[0]
+ res_keys = key_json[lookup_triplet]
+ self.assertEqual(len(res_keys), 1)
+ res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -586,9 +586,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
- res = key_json[lookup_triplet]
- self.assertEqual(len(res), 1)
- res = res[0]
+ res_keys = key_json[lookup_triplet]
+ self.assertEqual(len(res_keys), 1)
+ res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -707,9 +707,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
- res = key_json[lookup_triplet]
- self.assertEqual(len(res), 1)
- res = res[0]
+ res_keys = key_json[lookup_triplet]
+ self.assertEqual(len(res_keys), 1)
+ res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index a9893def74..6fb1f1bd6e 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -31,7 +31,11 @@ from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ override_config,
+)
@attr.s
@@ -152,11 +156,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
- fed_transport_client = Mock(spec=["send_transaction"])
- fed_transport_client.send_transaction = simple_async_mock({})
+ self.fed_transport_client = Mock(spec=["send_transaction"])
+ self.fed_transport_client.send_transaction = simple_async_mock({})
hs = self.setup_test_homeserver(
- federation_transport_client=fed_transport_client,
+ federation_transport_client=self.fed_transport_client,
)
load_legacy_presence_router(hs)
@@ -418,7 +422,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
#
# Thus we reset the mock, and try sending all online local user
# presence again
- self.hs.get_federation_transport_client().send_transaction.reset_mock()
+ self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
@@ -443,9 +447,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
}
found_users = set()
- calls = (
- self.hs.get_federation_transport_client().send_transaction.call_args_list
- )
+ calls = self.fed_transport_client.send_transaction.call_args_list
for call in calls:
call_args = call[0]
federation_transaction: Transaction = call_args[0]
@@ -470,7 +472,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def send_presence_update(
- testcase: FederatingHomeserverTestCase,
+ testcase: HomeserverTestCase,
user_id: str,
access_token: str,
presence_state: str,
@@ -491,7 +493,7 @@ def send_presence_update(
def sync_presence(
- testcase: FederatingHomeserverTestCase,
+ testcase: HomeserverTestCase,
user_id: str,
since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]:
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 9f1115dd23..33af8770fd 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -17,27 +17,25 @@ from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
from synapse.rest.client import login, room
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
return config
- def test_complexity_simple(self):
-
+ def test_complexity_simple(self) -> None:
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
@@ -56,7 +54,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Artificially raise the complexity
store = self.hs.get_datastores().main
- store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
+
+ async def get_current_state_event_counts(room_id: str) -> int:
+ return int(500 * 1.23)
+
+ store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
# Get the room complexity again -- make sure it's our artificial value
channel = self.make_signed_federation_request(
@@ -66,8 +68,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)
- def test_join_too_large(self):
-
+ def test_join_too_large(self) -> None:
u1 = self.register_user("u1", "pass")
handler = self.hs.get_room_member_handler()
@@ -75,12 +76,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -95,7 +96,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_join_too_large_admin(self):
+ def test_join_too_large_admin(self) -> None:
# Check whether an admin can join if option "admins_can_join" is undefined,
# this option defaults to false, so the join should fail.
@@ -106,12 +107,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -126,8 +127,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_join_too_large_once_joined(self):
-
+ def test_join_too_large_once_joined(self) -> None:
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
@@ -144,17 +144,18 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
- self.hs.get_datastores().main.get_current_state_event_counts = (
- lambda x: make_awaitable(600)
- )
+ async def get_current_state_event_counts(room_id: str) -> int:
+ return 600
+
+ self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
room_1,
UserID.from_string(u1),
@@ -180,7 +181,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
login.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config["limit_remote_rooms"] = {
"enabled": True,
@@ -189,7 +190,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
}
return config
- def test_join_too_large_no_admin(self):
+ def test_join_too_large_no_admin(self) -> None:
# A user which is not an admin should not be able to join a remote room
# which is too complex.
@@ -200,12 +201,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -220,7 +221,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_join_too_large_admin(self):
+ def test_join_too_large_admin(self) -> None:
# An admin should be able to join rooms where a complexity check fails.
u1 = self.register_user("u1", "pass", admin=True)
@@ -230,12 +231,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index b8fee72898..6381583c24 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,13 +1,21 @@
-from typing import List, Tuple
+from typing import Callable, List, Optional, Tuple
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.federation.sender import PerDestinationQueue, TransactionManager
-from synapse.federation.units import Edu
+from synapse.federation.sender import (
+ FederationSender,
+ PerDestinationQueue,
+ TransactionManager,
+)
+from synapse.federation.units import Edu, Transaction
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
@@ -28,35 +36,47 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
login.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(spec=["send_transaction"])
return self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=self.federation_transport_client,
)
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# stub out get_current_hosts_in_room
- state_handler = hs.get_state_handler()
+ state_storage_controller = hs.get_storage_controllers().state
# This mock is crucial for destination_rooms to be populated.
- state_handler.get_current_hosts_in_room = Mock(
- return_value=make_awaitable(["test", "host2"])
+ # TODO: this seems to no longer be the case---tests pass with this mock
+ # commented out.
+ state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment]
+ return_value=make_awaitable({"test", "host2"})
)
# whenever send_transaction is called, record the pdu data
- self.pdus = []
- self.failed_pdus = []
+ self.pdus: List[JsonDict] = []
+ self.failed_pdus: List[JsonDict] = []
self.is_online = True
- self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction
)
+ federation_sender = hs.get_federation_sender()
+ assert isinstance(federation_sender, FederationSender)
+ self.federation_sender = federation_sender
+
def default_config(self) -> JsonDict:
config = super().default_config()
config["federation_sender_instances"] = None
return config
- async def record_transaction(self, txn, json_cb):
- if self.is_online:
+ async def record_transaction(
+ self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]]
+ ) -> JsonDict:
+ if json_cb is None:
+ # The tests seem to expect that this method raises in this situation.
+ raise Exception("Blank json_cb")
+ elif self.is_online:
data = json_cb()
self.pdus.extend(data["pdus"])
return {}
@@ -92,7 +112,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
)[0]
return {"event_id": event_id, "stream_ordering": stream_ordering}
- def test_catch_up_destination_rooms_tracking(self):
+ def test_catch_up_destination_rooms_tracking(self) -> None:
"""
Tests that we populate the `destination_rooms` table as needed.
"""
@@ -117,7 +137,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertEqual(row_2["event_id"], event_id_2)
self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
- def test_catch_up_last_successful_stream_ordering_tracking(self):
+ def test_catch_up_last_successful_stream_ordering_tracking(self) -> None:
"""
Tests that we populate the `destination_rooms` table as needed.
"""
@@ -174,7 +194,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
"Send succeeded but not marked as last_successful_stream_ordering",
)
- def test_catch_up_from_blank_state(self):
+ def test_catch_up_from_blank_state(self) -> None:
"""
Runs an overall test of federation catch-up from scratch.
Further tests will focus on more narrow aspects and edge-cases, but I
@@ -218,11 +238,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# let's delete the federation transmission queue
# (this pretends we are starting up fresh.)
self.assertFalse(
- self.hs.get_federation_sender()
- ._per_destination_queues["host2"]
- .transmission_loop_running
+ self.federation_sender._per_destination_queues[
+ "host2"
+ ].transmission_loop_running
)
- del self.hs.get_federation_sender()._per_destination_queues["host2"]
+ del self.federation_sender._per_destination_queues["host2"]
# let's also clear any backoffs
self.get_success(
@@ -261,16 +281,15 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
destination_tm: str,
pending_pdus: List[EventBase],
_pending_edus: List[Edu],
- ) -> bool:
+ ) -> None:
assert destination == destination_tm
results_list.extend(pending_pdus)
- return True # success!
- transaction_manager.send_new_transaction = fake_send
+ transaction_manager.send_new_transaction = fake_send # type: ignore[assignment]
return per_dest_queue, results_list
- def test_catch_up_loop(self):
+ def test_catch_up_loop(self) -> None:
"""
Tests the behaviour of _catch_up_transmission_loop.
"""
@@ -312,6 +331,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# also fetch event 5 so we know its last_successful_stream_ordering later
event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5))
+ assert event_2.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_2.internal_metadata.stream_ordering
@@ -334,7 +354,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
event_5.internal_metadata.stream_ordering,
)
- def test_catch_up_on_synapse_startup(self):
+ def test_catch_up_on_synapse_startup(self) -> None:
"""
Tests the behaviour of get_catch_up_outstanding_destinations and
_wake_destinations_needing_catchup.
@@ -412,18 +432,19 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# patch wake_destination to just count the destinations instead
woken = []
- def wake_destination_track(destination):
+ def wake_destination_track(destination: str) -> None:
woken.append(destination)
- self.hs.get_federation_sender().wake_destination = wake_destination_track
+ self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment]
# cancel the pre-existing timer for _wake_destinations_needing_catchup
# this is because we are calling it manually rather than waiting for it
# to be called automatically
- self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
+ assert self.federation_sender._catchup_after_startup_timer is not None
+ self.federation_sender._catchup_after_startup_timer.cancel()
self.get_success(
- self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
+ self.federation_sender._wake_destinations_needing_catchup(), by=5.0
)
# ASSERT (_wake_destinations_needing_catchup):
@@ -432,7 +453,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1])
- def test_not_latest_event(self):
+ def test_not_latest_event(self) -> None:
"""Test that we send the latest event in the room even if its not ours."""
per_dest_queue, sent_pdus = self.make_fake_destination_queue()
@@ -465,6 +486,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
)
)
+ assert event_1.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_1.internal_metadata.stream_ordering
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index e67f405826..91694e4fca 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -36,7 +36,9 @@ class FederationClientTest(FederatingHomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
super().prepare(reactor, clock, homeserver)
# mock out the Agent used by the federation client, which is easier than
@@ -51,7 +53,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
self.creator = f"@creator:{self.OTHER_SERVER_NAME}"
self.test_room_id = "!room_id"
- def test_get_room_state(self):
+ def test_get_room_state(self) -> None:
# mock up some events to use in the response.
# In real life, these would have things in `prev_events` and `auth_events`, but that's
# a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
@@ -140,7 +142,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
["m.room.create", "m.room.member", "m.room.power_levels"],
)
- def test_get_pdu_returns_nothing_when_event_does_not_exist(self):
+ def test_get_pdu_returns_nothing_when_event_does_not_exist(self) -> None:
"""No event should be returned when the event does not exist"""
pulled_pdu_info = self.get_success(
self.hs.get_federation_client().get_pdu(
@@ -151,11 +153,11 @@ class FederationClientTest(FederatingHomeserverTestCase):
)
self.assertEqual(pulled_pdu_info, None)
- def test_get_pdu(self):
+ def test_get_pdu(self) -> None:
"""Test to make sure an event is returned by `get_pdu()`"""
self._get_pdu_once()
- def test_get_pdu_event_from_cache_is_pristine(self):
+ def test_get_pdu_event_from_cache_is_pristine(self) -> None:
"""Test that modifications made to events returned by `get_pdu()`
do not propagate back to to the internal cache (events returned should
be a copy).
@@ -176,7 +178,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9,
)
)
- self.assertIsNotNone(pulled_pdu_info2)
+ assert pulled_pdu_info2 is not None
remote_pdu2 = pulled_pdu_info2.pdu
# Sanity check that we are working against the same event
@@ -224,7 +226,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9,
)
)
- self.assertIsNotNone(pulled_pdu_info)
+ assert pulled_pdu_info is not None
remote_pdu = pulled_pdu_info.pdu
# check the right call got made to the agent
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 8692d8190f..9e104fd96a 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -11,18 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Callable, FrozenSet, List, Optional, Set
from unittest.mock import Mock
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
+from synapse.federation.units import Transaction
+from synapse.handlers.device import DeviceHandler
from synapse.rest import admin
from synapse.rest.client import login
+from synapse.server import HomeServer
from synapse.types import JsonDict, ReadReceipt
+from synapse.util import Clock
from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase
@@ -36,16 +41,17 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
re-enabled for the main process.
"""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(spec=["send_transaction"])
hs = self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=self.federation_transport_client,
)
- hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
+ hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
return_value=make_awaitable({"test", "host2"})
)
- hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
+ hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment]
hs.get_storage_controllers().state.get_current_hosts_in_room
)
@@ -56,10 +62,8 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
config["federation_sender_instances"] = None
return config
- def test_send_receipts(self):
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ def test_send_receipts(self) -> None:
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -98,10 +102,8 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
],
)
- def test_send_receipts_thread(self):
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ def test_send_receipts_thread(self) -> None:
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
# Create receipts for:
@@ -174,12 +176,10 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
],
)
- def test_send_receipts_with_backoff(self):
+ def test_send_receipts_with_backoff(self) -> None:
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -272,51 +272,60 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
login.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(
+ spec=["send_transaction", "query_user_devices"]
+ )
return self.setup_test_homeserver(
- federation_transport_client=Mock(
- spec=["send_transaction", "query_user_devices"]
- ),
+ federation_transport_client=self.federation_transport_client,
)
- def default_config(self):
+ def default_config(self) -> JsonDict:
c = super().default_config()
# Enable federation sending on the main process.
c["federation_sender_instances"] = None
return c
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
test_room_id = "!room:host1"
# stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
# server thinks the user shares a room with `@user2:host2`
- def get_rooms_for_user(user_id):
- return defer.succeed({test_room_id})
+ def get_rooms_for_user(user_id: str) -> "defer.Deferred[FrozenSet[str]]":
+ return defer.succeed(frozenset({test_room_id}))
- hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+ hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user # type: ignore[assignment]
- async def get_current_hosts_in_room(room_id):
+ async def get_current_hosts_in_room(room_id: str) -> Set[str]:
if room_id == test_room_id:
- return ["host2"]
+ return {"host2"}
+ else:
+ # TODO: We should fail the test when we encounter an unxpected room ID.
+ # We can't just use `self.fail(...)` here because the app code is greedy
+ # with `Exception` and will catch it before the test can see it.
+ return set()
- # TODO: We should fail the test when we encounter an unxpected room ID.
- # We can't just use `self.fail(...)` here because the app code is greedy
- # with `Exception` and will catch it before the test can see it.
+ hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
- hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self.device_handler = device_handler
# whenever send_transaction is called, record the edu data
- self.edus = []
- self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.edus: List[JsonDict] = []
+ self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction
)
- def record_transaction(self, txn, json_cb):
+ def record_transaction(
+ self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
+ ) -> "defer.Deferred[JsonDict]":
+ assert json_cb is not None
data = json_cb()
self.edus.extend(data["edus"])
return defer.succeed({})
- def test_send_device_updates(self):
+ def test_send_device_updates(self) -> None:
"""Basic case: each device update should result in an EDU"""
# create a device
u1 = self.register_user("user", "pass")
@@ -340,12 +349,12 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(len(self.edus), 1)
self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
- def test_dont_send_device_updates_for_remote_users(self):
+ def test_dont_send_device_updates_for_remote_users(self) -> None:
"""Check that we don't send device updates for remote users"""
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
- self.hs.get_federation_transport_client().query_user_devices.return_value = (
+ self.federation_transport_client.query_user_devices.return_value = (
make_awaitable(
{
"stream_id": "1",
@@ -356,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
)
self.get_success(
- self.hs.get_device_handler().device_list_updater.incoming_device_list_update(
+ self.device_handler.device_list_updater.incoming_device_list_update(
"host2",
{
"user_id": "@user2:host2",
@@ -379,7 +388,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
)
self.assertIn("D1", devices)
- def test_upload_signatures(self):
+ def test_upload_signatures(self) -> None:
"""Uploading signatures on some devices should produce updates for that user"""
e2e_handler = self.hs.get_e2e_keys_handler()
@@ -391,7 +400,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# expect two edus
self.assertEqual(len(self.edus), 2)
- stream_id = None
+ stream_id: Optional[int] = None
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
@@ -473,13 +482,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
if stream_id is not None:
- self.assertEqual(c["prev_id"], [stream_id])
+ self.assertEqual(c["prev_id"], [stream_id]) # type: ignore[unreachable]
self.assertGreaterEqual(c["stream_id"], stream_id)
stream_id = c["stream_id"]
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2"}, devices)
- def test_delete_devices(self):
+ def test_delete_devices(self) -> None:
"""If devices are deleted, that should result in EDUs too"""
# create devices
@@ -499,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -521,11 +528,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2", "D3"}, devices)
- def test_unreachable_server(self):
+ def test_unreachable_server(self) -> None:
"""If the destination server is unreachable, all the updates should get sent on
recovery
"""
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -535,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -555,7 +560,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# for each device, there should be a single update
self.assertEqual(len(self.edus), 3)
- stream_id = None
+ stream_id: Optional[int] = None
for edu in self.edus:
self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
@@ -566,13 +571,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2", "D3"}, devices)
- def test_prune_outbound_device_pokes1(self):
+ def test_prune_outbound_device_pokes1(self) -> None:
"""If a destination is unreachable, and the updates are pruned, we should get
a single update.
This case tests the behaviour when the server has never been reachable.
"""
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -582,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -615,7 +618,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# synapse uses an empty prev_id list to indicate "needs a full resync".
self.assertEqual(c["prev_id"], [])
- def test_prune_outbound_device_pokes2(self):
+ def test_prune_outbound_device_pokes2(self) -> None:
"""If a destination is unreachable, and the updates are pruned, we should get
a single update.
@@ -632,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
# now the server goes offline
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
self.login("user", "pass", device_id="D2")
@@ -643,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.reactor.advance(1)
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
self.assertGreaterEqual(mock_send_txn.call_count, 3)
@@ -741,7 +742,7 @@ def encode_pubkey(sk: SigningKey) -> str:
return key.encode_verify_key_base64(key.get_verify_key(sk))
-def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
+def build_device_dict(user_id: str, device_id: str, sk: SigningKey) -> JsonDict:
"""Build a dict representing the given device"""
return {
"user_id": user_id,
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index be719e49c0..6c7738d810 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -21,7 +21,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -34,7 +34,6 @@ from tests.unittest import override_config
class FederationServerTests(unittest.FederatingHomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
@@ -42,7 +41,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
]
@parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)])
- def test_bad_request(self, query_content):
+ def test_bad_request(self, query_content: bytes) -> None:
"""
Querying with bad data returns a reasonable error code.
"""
@@ -64,7 +63,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
class ServerACLsTestCase(unittest.TestCase):
- def test_blacklisted_server(self):
+ def test_blacklisted_server(self) -> None:
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
logging.info("ACL event: %s", e.content)
@@ -74,7 +73,7 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertTrue(server_matches_acl_event("evil.com.au", e))
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
- def test_block_ip_literals(self):
+ def test_block_ip_literals(self) -> None:
e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
logging.info("ACL event: %s", e.content)
@@ -83,7 +82,7 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertFalse(server_matches_acl_event("[1:2::]", e))
self.assertTrue(server_matches_acl_event("1:2:3:4", e))
- def test_wildcard_matching(self):
+ def test_wildcard_matching(self) -> None:
e = _create_acl_event({"allow": ["good*.com"]})
self.assertTrue(
server_matches_acl_event("good.com", e),
@@ -110,7 +109,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
login.register_servlets,
]
- def test_needs_to_be_in_room(self):
+ def test_needs_to_be_in_room(self) -> None:
"""/v1/state/<room_id> requires the server to be in the room"""
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
@@ -131,7 +130,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self._storage_controllers = hs.get_storage_controllers()
@@ -157,7 +156,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body
- def test_send_join(self):
+ def test_send_join(self) -> None:
"""happy-path test of send_join"""
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
join_result = self._make_join(joining_user)
@@ -324,7 +323,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# is probably sufficient to reassure that the bucket is updated.
-def _create_acl_event(content):
+def _create_acl_event(content: JsonDict) -> EventBase:
return make_event_from_dict(
{
"room_id": "!a:b",
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index e88e5d8bb3..55655de862 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -15,6 +15,8 @@
from http import HTTPStatus
from typing import Dict, List, Tuple
+from twisted.web.resource import Resource
+
from synapse.api.errors import Codes
from synapse.federation.transport.server import BaseFederationServlet
from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
@@ -62,7 +64,7 @@ class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCa
path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}"
- def create_test_resource(self):
+ def create_test_resource(self) -> Resource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs)
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index ff589c0b6c..70209ab090 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, JoinRules, Membership
-from synapse.api.room_versions import RoomVersions
-from synapse.events import builder
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.events import EventBase, builder
+from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import RoomAlias
+from synapse.util import Clock
from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase
@@ -197,7 +201,9 @@ class FederationKnockingTestCase(
login.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
# We're not going to be properly signing events as our remote homeserver is fake,
@@ -205,23 +211,29 @@ class FederationKnockingTestCase(
# Note that these checks are not relevant to this test case.
# Have this homeserver auto-approve all event signature checking.
- async def approve_all_signature_checking(_, pdu):
+ async def approve_all_signature_checking(
+ room_version: RoomVersion,
+ pdu: EventBase,
+ record_failure_callback: Any = None,
+ ) -> EventBase:
return pdu
- homeserver.get_federation_server()._check_sigs_and_hash = (
+ homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[assignment]
approve_all_signature_checking
)
# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
- async def _check_event_auth(origin, event, context, *args, **kwargs):
- return context
+ async def _check_event_auth(
+ origin: Optional[str], event: EventBase, context: EventContext
+ ) -> None:
+ pass
- homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth
+ homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
return super().prepare(reactor, clock, homeserver)
- def test_room_state_returned_when_knocking(self):
+ def test_room_state_returned_when_knocking(self) -> None:
"""
Tests that specific, stripped state events from a room are returned after
a remote homeserver successfully knocks on a local room.
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index cfd550a04b..c4231f4aa9 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -20,7 +20,7 @@ from tests.unittest import DEBUG, override_config
class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
@override_config({"allow_public_rooms_over_federation": False})
- def test_blocked_public_room_list_over_federation(self):
+ def test_blocked_public_room_list_over_federation(self) -> None:
"""Test that unauthenticated requests to the public rooms directory 403 when
allow_public_rooms_over_federation is False.
"""
@@ -31,7 +31,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(403, channel.code)
@override_config({"allow_public_rooms_over_federation": True})
- def test_open_public_room_list_over_federation(self):
+ def test_open_public_room_list_over_federation(self) -> None:
"""Test that unauthenticated requests to the public rooms directory 200 when
allow_public_rooms_over_federation is True.
"""
@@ -42,7 +42,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(200, channel.code)
@DEBUG
- def test_edu_debugging_doesnt_explode(self):
+ def test_edu_debugging_doesnt_explode(self) -> None:
"""Sanity check incoming federation succeeds with `synapse.debug_8631` enabled.
Remove this when we strip out issue_8631_logger.
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 6f300b8e11..5569ccef8a 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, JoinRules
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
+from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@@ -296,3 +297,58 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0][0]["user_agent"], "user_agent")
self.assertGreater(args[0][0]["last_seen"], 0)
self.assertNotIn("access_token", args[0][0])
+
+ def test_account_data(self) -> None:
+ """Tests that user account data get exported."""
+ # add account data
+ self.get_success(
+ self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1})
+ )
+ self.get_success(
+ self._store.add_account_data_to_room(
+ self.user2, "test_room", "m.per_room", {"b": 2}
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ # two calls, one call for user data and one call for room data
+ writer.write_account_data.assert_called()
+
+ args = writer.write_account_data.call_args_list[0][0]
+ self.assertEqual(args[0], "global")
+ self.assertEqual(args[1]["m.global"]["a"], 1)
+
+ args = writer.write_account_data.call_args_list[1][0]
+ self.assertEqual(args[0], "test_room")
+ self.assertEqual(args[1]["m.per_room"]["b"], 2)
+
+ def test_media_ids(self) -> None:
+ """Tests that media's metadata get exported."""
+
+ self.get_success(
+ self._store.store_local_media(
+ media_id="media_1",
+ media_type="image/png",
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ media_length=50,
+ user_id=UserID.from_string(self.user2),
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_media_id.assert_called_once()
+
+ args = writer.write_media_id.call_args[0]
+ self.assertEqual(args[0], "media_1")
+ self.assertEqual(args[1]["media_id"], "media_1")
+ self.assertEqual(args[1]["media_length"], 50)
+ self.assertGreater(args[1]["created_ts"], 0)
+ self.assertIsNone(args[1]["upload_name"])
+ self.assertIsNone(args[1]["last_access_ts"])
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7495ab21a..9014e60577 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock(
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 2733719d82..63aad0d10c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("test_user", {})
request = _mock_request()
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
cas_response = CasResponse("test_user", {})
@@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("föö", {})
request = _mock_request()
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 95698bc275..6b4cba65d0 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# we should now have an unused alg1 key
- res = self.get_success(
+ fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, ["alg1"])
+ self.assertEqual(fallback_res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback
# key
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# we shouldn't have any unused fallback keys again
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, [])
+ self.assertEqual(unused_res, [])
# claiming an OTK again should return the same fallback key
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
@@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, [])
+ self.assertEqual(unused_res, [])
# uploading a new fallback key should result in an unused fallback key
self.get_success(
@@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, ["alg1"])
+ self.assertEqual(unused_res, ["alg1"])
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
@@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)
@@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
@@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
# upload two device keys, which will be signed later by the self-signing key
- device_key_1 = {
+ device_key_1: JsonDict = {
"user_id": local_user,
"device_id": "abc",
"algorithms": [
@@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
}
- device_key_2 = {
+ device_key_2: JsonDict = {
"user_id": local_user,
"device_id": "def",
"algorithms": [
@@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
+ device_handler = self.hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
e = self.get_failure(
- self.hs.get_device_handler().check_device_registered(
+ device_handler.check_device_registered(
user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name",
@@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_id = "xyz"
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
- device_key = {
+ device_key: JsonDict = {
"user_id": local_user,
"device_id": device_id,
"algorithms": [
@@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
- master_key = {
+ master_key: JsonDict = {
"user_id": local_user,
"usage": ["master"],
"keys": {"ed25519:" + master_pubkey: master_pubkey},
@@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# the first user
other_user = "@otherboris:" + self.hs.hostname
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
- other_master_key = {
+ other_master_key: JsonDict = {
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
"user_id": other_user,
"usage": ["master"],
@@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_client_keys = mock.Mock(
+ self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
@@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_user_devices = mock.Mock(
+ self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..bf0862ed54 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
- self.hs.get_federation_client().backfill = federation_client_backfill_mock
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
# We also mock the persist method with a side effect of itself. This allows us
# to track when it has been called while preserving its function.
persist_events_and_notify_mock = Mock(
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
)
- self.hs.get_federation_event_handler().persist_events_and_notify = (
+ self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
persist_events_and_notify_mock
)
@@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_client = fed_handler.federation_client
room_id = "!room:example.com"
- membership_event = make_event_from_dict(
- {
- "room_id": room_id,
- "type": "m.room.member",
- "sender": "@alice:test",
- "state_key": "@alice:test",
- "content": {"membership": "join"},
- },
- RoomVersions.V10,
- )
-
- mock_make_membership_event = Mock(
- return_value=make_awaitable(
- (
- "example.com",
- membership_event,
- RoomVersions.V10,
- )
- )
- )
EVENT_CREATE = make_event_from_dict(
{
@@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
},
room_version=RoomVersions.V10,
)
+ membership_event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@alice:test",
+ "state_key": "@alice:test",
+ "content": {"membership": "join"},
+ "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
+ },
+ RoomVersions.V10,
+ )
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
+ )
+ )
+ )
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Try to start another partial state sync.
# Nothing should happen.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# End the partial state sync
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
# The next attempt to start the partial state sync should work.
is_partial_state = True
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
def test_partial_state_room_sync_restart(self) -> None:
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Fail the partial state sync.
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Start the partial state sync again.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Deduplicate another partial state sync.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Fail the partial state sync.
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
mock_sync_partial_state_room.assert_called_with(
initial_destination="hs3",
- other_destinations=["hs2"],
+ other_destinations={"hs2"},
room_id="room_id",
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 70ea4d15d4..c067e5bfe3 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
+from synapse.state import StateResolutionStore
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
from synapse.util import Clock
@@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
self.get_success(
persistence.persist_event(
prev_event,
@@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
- state_res_store=main_store,
+ state_res_store=StateResolutionStore(main_store),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
@@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id,
],
event_map={},
- state_res_store=main_store,
+ state_res_store=StateResolutionStore(main_store),
full_conflicted_set=set(),
)
),
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index c4727ab917..9691d66b48 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler()
- self._persist_event_storage_controller = (
- self.hs.get_storage_controllers().persistence
- )
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persist_event_storage_controller = persistence
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.info = self.get_success(
+ info = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(
self.access_token,
)
)
- self.token_id = self.info.token_id
+ assert info is not None
+ self.token_id = info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
@@ -78,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
return memberEvent, memberEventContext
- def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
+ def _create_duplicate_event(
+ self, txn_id: str
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
"""
@@ -106,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random"
- event1, context = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success(
self.handler.handle_new_client_event(
@@ -118,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id)
- event2, context = self._create_duplicate_event(txn_id)
+ event2, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works,
# so we want to make sure we test with different events.
@@ -139,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right
# thing.
- event3, context = self._create_duplicate_event(txn_id)
+ event3, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event3))
+
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
@@ -153,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right
# thing.
- event4, context = self._create_duplicate_event(txn_id)
+ event4, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event4))
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
@@ -173,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time
- event1, context1 = self._create_duplicate_event(txn_id)
- event2, context2 = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
+ context1 = self.get_success(unpersisted_context1.persist(event1))
+ event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
+ context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with
self.assertNotEqual(event1.event_id, event2.event_id)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index ef5311ce64..bb52b3b1af 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver()
self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
- self.hs_patcher.start()
+ self.hs_patcher.start() # type: ignore[attr-defined]
self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"]
@@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def tearDown(self) -> None:
- self.hs_patcher.stop()
+ self.hs_patcher.stop() # type: ignore[attr-defined]
return super().tearDown()
def reset_mocks(self) -> None:
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 0916de64f5..aa91bc0a3d 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
- self.hs.get_identity_handler().send_threepid_validation = Mock(
+ self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
return_value=make_awaitable(0),
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b9332d97dc..aff1ec4758 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -62,7 +62,7 @@ class TestSpamChecker:
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class DenyAll(TestSpamChecker):
@@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker:
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
@@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self) -> None:
- self.store.count_monthly_users = Mock(
+ self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"
- self.store.count_real_users = Mock(return_value=make_awaitable(1))
+ self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self,
) -> None:
- self.store.count_real_users = Mock(return_value=make_awaitable(2))
+ self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly not federated.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["federatable"])
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "public")
@@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a public room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertEqual(room["join_rules"], "public")
# Both users should be in the room.
@@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
@@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
@@ -503,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creation_handler.create_event(
requester,
{
@@ -515,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_creation_handler.handle_new_client_event(
requester, events_and_context=[(event, context)]
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 9b1b8b9f13..b5c772a7ae 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
saml_response = FakeAuthnResponse(
@@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None)
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
request = _mock_request()
@@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None)
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
# register a user to occupy the first-choice MXID
store = self.hs.get_datastores().main
@@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
index 137deab138..d6f43a98fc 100644
--- a/tests/handlers/test_sso.py
+++ b/tests/handlers/test_sso.py
@@ -113,7 +113,6 @@ async def mock_get_file(
headers: Optional[RawHeaders] = None,
is_allowed_content_type: Optional[Callable[[str], bool]] = None,
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
-
fake_response = FakeResponse(code=404)
if url == "http://my.server/me.png":
fake_response = FakeResponse(
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index f1a50c5bcb..d11ded6c5b 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -31,7 +31,6 @@ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
class StatsRoomTests(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1fe9563c98..94518a7196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too
- mock_federation_client = Mock(spec=["put_json"])
- mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+ self.mock_federation_client = Mock(spec=["put_json"])
+ self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
@@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.mock_hs_notifier = Mock()
hs = self.setup_test_homeserver(
notifier=self.mock_hs_notifier,
- federation_http_client=mock_federation_client,
+ federation_http_client=self.mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
@@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- put_json = self.hs.get_federation_http_client().put_json
- put_json.assert_called_once_with(
+ self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
@@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
- put_json = self.hs.get_federation_http_client().put_json
- put_json.assert_called_once_with(
+ self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..a02c1c6227 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple
+from typing import Any, Tuple
from unittest.mock import Mock, patch
from urllib.parse import quote
@@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService
from synapse.rest.client import login, register, room, user_directory
from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo
-from synapse.types import create_requester
+from synapse.types import UserProfile, create_requester
from synapse.util import Clock
from tests import unittest
@@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
+# A spam checker which doesn't implement anything, so create a bare object.
+class UselessSpamChecker:
+ def __init__(self, config: Any):
+ pass
+
+
class UserDirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the UserDirectoryHandler.
@@ -186,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
self._check_only_one_user_in_directory(user, room)
+ def test_search_term_with_colon_in_it_does_not_raise(self) -> None:
+ """
+ Regression test: Test that search terms with colons in them are acceptable.
+ """
+ u1 = self.register_user("user1", "pass")
+ self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10))
+
def test_user_not_in_users_table(self) -> None:
"""Unclear how it happens, but on matrix.org we've seen join events
for users who aren't in the users table. Test that we don't fall over
@@ -773,7 +786,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
- async def allow_all(user_profile: ProfileInfo) -> bool:
+ async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users.
return False
@@ -787,7 +800,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- async def block_all(user_profile: ProfileInfo) -> bool:
+ async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy.
return True
@@ -797,6 +810,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
+ @override_config(
+ {
+ "spam_checker": {
+ "module": "tests.handlers.test_user_directory.UselessSpamChecker"
+ }
+ }
+ )
def test_legacy_spam_checker(self) -> None:
"""
A spam checker without the expected method should be ignored.
@@ -825,11 +845,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
self.assertEqual(public_users, set())
- # Configure a spam checker.
- spam_checker = self.hs.get_spam_checker()
- # The spam checker doesn't need any methods, so create a bare object.
- spam_checker.spam_checker = object()
-
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
@@ -949,13 +964,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self.hs.get_storage_controllers().persistence.persist_event(event, context)
- )
+ context = self.get_success(unpersisted_context.persist(event))
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.get_success(persistence.persist_event(event, context))
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
"""We've chosen to simplify the user directory's implementation by
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 093537adef..528cdee34b 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -19,13 +19,15 @@ from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import Connection
+from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
-def get_test_https_policy():
+def get_test_https_policy() -> BrowserLikePolicyForHTTPS:
"""Get a test IPolicyForHTTPS which trusts the test CA cert
Returns:
@@ -39,7 +41,7 @@ def get_test_https_policy():
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
-def get_test_ca_cert_file():
+def get_test_ca_cert_file() -> str:
"""Get the path to the test CA cert
The keypair is generated with:
@@ -51,7 +53,7 @@ def get_test_ca_cert_file():
return os.path.join(os.path.dirname(__file__), "ca.crt")
-def get_test_key_file():
+def get_test_key_file() -> str:
"""get the path to the test key
The key file is made with:
@@ -137,15 +139,20 @@ class TestServerTLSConnectionFactory:
"""An SSL connection creator which returns connections which present a certificate
signed by our test CA."""
- def __init__(self, sanlist):
+ def __init__(self, sanlist: List[bytes]):
"""
Args:
- sanlist: list[bytes]: a list of subjectAltName values for the cert
+ sanlist: a list of subjectAltName values for the cert
"""
self._cert_file = create_test_cert_file(sanlist)
- def serverConnectionForTLS(self, tlsProtocol):
+ def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection:
ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_certificate_file(self._cert_file)
ctx.use_privatekey_file(get_test_key_file())
return Connection(ctx, None)
+
+
+# A dummy address, useful for tests that use FakeTransport and don't care about where
+# packets are going to/coming from.
+dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 992d8f94fd..eb7f53fee5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -14,7 +14,7 @@
import base64
import logging
import os
-from typing import Iterable, Optional
+from typing import Any, Awaitable, Callable, Generator, List, Optional, cast
from unittest.mock import Mock, patch
import treq
@@ -24,14 +24,19 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
-from twisted.internet.interfaces import IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.defer import Deferred
+from twisted.internet.endpoints import _WrappingProtocol
+from twisted.internet.interfaces import (
+ IOpenSSLClientConnectionCreator,
+ IProtocolFactory,
+)
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
from twisted.web.http import HTTPChannel, Request
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IPolicyForHTTPS
+from twisted.web.iweb import IPolicyForHTTPS, IResponse
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto.context_factory import FederationPolicyForHTTPS
@@ -42,27 +47,39 @@ from synapse.http.federation.well_known_resolver import (
WellKnownResolver,
_cache_period_from_headers,
)
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import (
+ SENTINEL_CONTEXT,
+ LoggingContext,
+ LoggingContextOrSentinel,
+ current_context,
+)
+from synapse.types import ISynapseReactor
from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
-from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.http import (
+ TestServerTLSConnectionFactory,
+ dummy_address,
+ get_test_ca_cert_file,
+)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.utils import default_config
+from tests.utils import checked_cast, default_config
logger = logging.getLogger(__name__)
# Once Async Mocks or lambdas are supported this can go away.
-def generate_resolve_service(result):
- async def resolve_service(_):
+def generate_resolve_service(
+ result: List[Server],
+) -> Callable[[Any], Awaitable[List[Server]]]:
+ async def resolve_service(_: Any) -> List[Server]:
return result
return resolve_service
class MatrixFederationAgentTests(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
self.mock_resolver = Mock()
@@ -75,8 +92,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.tls_factory = FederationPolicyForHTTPS(config)
- self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
- self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+ self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache(
+ "test_cache", timer=self.reactor.seconds
+ )
+ self.had_well_known_cache: TTLCache[bytes, bool] = TTLCache(
+ "test_cache", timer=self.reactor.seconds
+ )
self.well_known_resolver = WellKnownResolver(
self.reactor,
Agent(self.reactor, contextFactory=self.tls_factory),
@@ -89,8 +110,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
self,
client_factory: IProtocolFactory,
ssl: bool = True,
- expected_sni: bytes = None,
- tls_sanlist: Optional[Iterable[bytes]] = None,
+ expected_sni: Optional[bytes] = None,
+ tls_sanlist: Optional[List[bytes]] = None,
) -> HTTPChannel:
"""Builds a test server, and completes the outgoing client connection
Args:
@@ -116,8 +137,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
- server_protocol = server_factory.buildProtocol(None)
-
+ server_protocol = server_factory.buildProtocol(dummy_address)
+ assert server_protocol is not None
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
@@ -125,7 +146,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
- client_protocol = client_factory.buildProtocol(None)
+ # NB: we use a checked_cast here to workaround https://github.com/Shoobx/mypy-zope/issues/91)
+ client_protocol = checked_cast(
+ _WrappingProtocol, client_factory.buildProtocol(dummy_address)
+ )
client_protocol.makeConnection(
FakeTransport(server_protocol, self.reactor, client_protocol)
)
@@ -136,6 +160,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
)
if ssl:
+ assert isinstance(server_protocol, TLSMemoryBIOProtocol)
# fish the test server back out of the server-side TLS protocol.
http_protocol = server_protocol.wrappedProtocol
# grab a hold of the TLS connection, in case it gets torn down
@@ -144,6 +169,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
http_protocol = server_protocol
tls_connection = None
+ assert isinstance(http_protocol, HTTPChannel)
# give the reactor a pump to get the TLS juices flowing (if needed)
self.reactor.advance(0)
@@ -159,12 +185,14 @@ class MatrixFederationAgentTests(unittest.TestCase):
return http_protocol
@defer.inlineCallbacks
- def _make_get_request(self, uri: bytes):
+ def _make_get_request(
+ self, uri: bytes
+ ) -> Generator["Deferred[object]", object, IResponse]:
"""
Sends a simple GET request via the agent, and checks its logcontext management
"""
with LoggingContext("one") as context:
- fetch_d = self.agent.request(b"GET", uri)
+ fetch_d: Deferred[IResponse] = self.agent.request(b"GET", uri)
# Nothing happened yet
self.assertNoResult(fetch_d)
@@ -172,8 +200,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# should have reset logcontext to the sentinel
_check_logcontext(SENTINEL_CONTEXT)
+ fetch_res: IResponse
try:
- fetch_res = yield fetch_d
+ fetch_res = yield fetch_d # type: ignore[misc, assignment]
return fetch_res
except Exception as e:
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
@@ -216,7 +245,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
request: Request,
content: bytes,
headers: Optional[dict] = None,
- ):
+ ) -> None:
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
"""
@@ -237,16 +266,16 @@ class MatrixFederationAgentTests(unittest.TestCase):
because it is created too early during setUp
"""
return MatrixFederationAgent(
- reactor=self.reactor,
+ reactor=cast(ISynapseReactor, self.reactor),
tls_client_options_factory=self.tls_factory,
- user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
+ user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided.
ip_whitelist=IPSet(),
ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
- def test_get(self):
+ def test_get(self) -> None:
"""happy-path test of a GET request with an explicit port"""
self._do_get()
@@ -254,11 +283,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
os.environ,
{"https_proxy": "proxy.com", "no_proxy": "testserv"},
)
- def test_get_bypass_proxy(self):
+ def test_get_bypass_proxy(self) -> None:
"""test of a GET request with an explicit port and bypass proxy"""
self._do_get()
- def _do_get(self):
+ def _do_get(self) -> None:
"""test of a GET request with an explicit port"""
self.agent = self._make_agent()
@@ -318,7 +347,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
@patch.dict(
os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"}
)
- def test_get_via_http_proxy(self):
+ def test_get_via_http_proxy(self) -> None:
"""test for federation request through a http proxy"""
self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None)
@@ -326,7 +355,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
os.environ,
{"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"},
)
- def test_get_via_http_proxy_with_auth(self):
+ def test_get_via_http_proxy_with_auth(self) -> None:
"""test for federation request through a http proxy with authentication"""
self._do_get_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=b"user:pass"
@@ -335,7 +364,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
@patch.dict(
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
)
- def test_get_via_https_proxy(self):
+ def test_get_via_https_proxy(self) -> None:
"""test for federation request through a https proxy"""
self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None)
@@ -343,7 +372,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
os.environ,
{"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"},
)
- def test_get_via_https_proxy_with_auth(self):
+ def test_get_via_https_proxy_with_auth(self) -> None:
"""test for federation request through a https proxy with authentication"""
self._do_get_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=b"user:pass"
@@ -353,7 +382,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self,
expect_proxy_ssl: bool = False,
expected_auth_credentials: Optional[bytes] = None,
- ):
+ ) -> None:
"""Send a https federation request via an agent and check that it is correctly
received at the proxy and client. The proxy can use either http or https.
Args:
@@ -418,10 +447,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
- ).buildProtocol(None)
+ ).buildProtocol(dummy_address)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
+ assert proxy_server_transport is not None
server_ssl_protocol.makeConnection(proxy_server_transport)
# ... and replace the protocol on the proxy's transport with the
@@ -436,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -451,6 +482,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# now there should be a pending request
http_server = server_ssl_protocol.wrappedProtocol
+ assert isinstance(http_server, HTTPChannel)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
@@ -491,7 +523,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1})
- def test_get_ip_address(self):
+ def test_get_ip_address(self) -> None:
"""
Test the behaviour when the server name contains an explicit IP (with no port)
"""
@@ -526,7 +558,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
- def test_get_ipv6_address(self):
+ def test_get_ipv6_address(self) -> None:
"""
Test the behaviour when the server name contains an explicit IPv6 address
(with no port)
@@ -562,7 +594,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
- def test_get_ipv6_address_with_port(self):
+ def test_get_ipv6_address_with_port(self) -> None:
"""
Test the behaviour when the server name contains an explicit IPv6 address
(with explicit port)
@@ -598,7 +630,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
- def test_get_hostname_bad_cert(self):
+ def test_get_hostname_bad_cert(self) -> None:
"""
Test the behaviour when the certificate on the server doesn't match the hostname
"""
@@ -651,7 +683,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
failure_reason = e.value.reasons[0]
self.assertIsInstance(failure_reason.value, VerificationError)
- def test_get_ip_address_bad_cert(self):
+ def test_get_ip_address_bad_cert(self) -> None:
"""
Test the behaviour when the server name contains an explicit IP, but
the server cert doesn't cover it
@@ -684,7 +716,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
failure_reason = e.value.reasons[0]
self.assertIsInstance(failure_reason.value, VerificationError)
- def test_get_no_srv_no_well_known(self):
+ def test_get_no_srv_no_well_known(self) -> None:
"""
Test the behaviour when the server name has no port, no SRV, and no well-known
"""
@@ -740,7 +772,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
- def test_get_well_known(self):
+ def test_get_well_known(self) -> None:
"""Test the behaviour when the .well-known delegates elsewhere"""
self.agent = self._make_agent()
@@ -802,7 +834,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
- def test_get_well_known_redirect(self):
+ def test_get_well_known_redirect(self) -> None:
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
@@ -892,7 +924,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
- def test_get_invalid_well_known(self):
+ def test_get_invalid_well_known(self) -> None:
"""
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
"""
@@ -945,7 +977,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
- def test_get_well_known_unsigned_cert(self):
+ def test_get_well_known_unsigned_cert(self) -> None:
"""Test the behaviour when the .well-known server presents a cert
not signed by a CA
"""
@@ -969,7 +1001,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
- self.reactor,
+ cast(ISynapseReactor, self.reactor),
Agent(self.reactor, contextFactory=tls_factory),
b"test-agent",
well_known_cache=self.well_known_cache,
@@ -999,7 +1031,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
b"_matrix._tcp.testserv"
)
- def test_get_hostname_srv(self):
+ def test_get_hostname_srv(self) -> None:
"""
Test the behaviour when there is a single SRV record
"""
@@ -1041,7 +1073,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
- def test_get_well_known_srv(self):
+ def test_get_well_known_srv(self) -> None:
"""Test the behaviour when the .well-known redirects to a place where there
is a SRV.
"""
@@ -1101,7 +1133,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
- def test_idna_servername(self):
+ def test_idna_servername(self) -> None:
"""test the behaviour when the server name has idna chars in"""
self.agent = self._make_agent()
@@ -1163,7 +1195,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
- def test_idna_srv_target(self):
+ def test_idna_srv_target(self) -> None:
"""test the behaviour when the target of a SRV record has idna chars"""
self.agent = self._make_agent()
@@ -1206,7 +1238,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
- def test_well_known_cache(self):
+ def test_well_known_cache(self) -> None:
self.reactor.lookups["testserv"] = "1.2.3.4"
fetch_d = defer.ensureDeferred(
@@ -1262,7 +1294,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"other-server")
- def test_well_known_cache_with_temp_failure(self):
+ def test_well_known_cache_with_temp_failure(self) -> None:
"""Test that we refetch well-known before the cache expires, and that
it ignores transient errors.
"""
@@ -1341,7 +1373,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, None)
- def test_well_known_too_large(self):
+ def test_well_known_too_large(self) -> None:
"""A well-known query that returns a result which is too large should be rejected."""
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -1367,7 +1399,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
r = self.successResultOf(fetch_d)
self.assertIsNone(r.delegated_server)
- def test_srv_fallbacks(self):
+ def test_srv_fallbacks(self) -> None:
"""Test that other SRV results are tried if the first one fails."""
self.agent = self._make_agent()
@@ -1427,7 +1459,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
class TestCachePeriodFromHeaders(unittest.TestCase):
- def test_cache_control(self):
+ def test_cache_control(self) -> None:
# uppercase
self.assertEqual(
_cache_period_from_headers(
@@ -1464,7 +1496,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
0,
)
- def test_expires(self):
+ def test_expires(self) -> None:
self.assertEqual(
_cache_period_from_headers(
Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}),
@@ -1491,15 +1523,15 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0)
-def _check_logcontext(context):
+def _check_logcontext(context: LoggingContextOrSentinel) -> None:
current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
def _wrap_server_factory_for_tls(
- factory: IProtocolFactory, sanlist: Iterable[bytes] = None
-) -> IProtocolFactory:
+ factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
+) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
@@ -1537,7 +1569,7 @@ def _get_test_protocol_factory() -> IProtocolFactory:
return server_factory
-def _log_request(request: str):
+def _log_request(request: str) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info(f"Completed request {request}")
@@ -1547,6 +1579,8 @@ class TrustingTLSPolicyForHTTPS:
"""An IPolicyForHTTPS which checks that the certificate belongs to the
right server, but doesn't check the certificate chain."""
- def creatorForNetloc(self, hostname, port):
+ def creatorForNetloc(
+ self, hostname: bytes, port: int
+ ) -> IOpenSSLClientConnectionCreator:
certificateOptions = OpenSSLCertificateOptions()
return ClientTLSOptions(hostname, certificateOptions.getContext())
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 77ce8432ac..6ab13357f9 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Dict, Generator, List, Tuple, cast
from unittest.mock import Mock
from twisted.internet import defer
@@ -20,7 +20,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectError
from twisted.names import dns, error
-from synapse.http.federation.srv_resolver import SrvResolver
+from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.logging.context import LoggingContext, current_context
from tests import unittest
@@ -28,7 +28,7 @@ from tests.utils import MockClock
class SrvResolverTestCase(unittest.TestCase):
- def test_resolve(self):
+ def test_resolve(self) -> None:
dns_client_mock = Mock()
service_name = b"test_service.example.com"
@@ -38,18 +38,18 @@ class SrvResolverTestCase(unittest.TestCase):
type=dns.SRV, payload=dns.Record_SRV(target=host_name)
)
- result_deferred = Deferred()
+ result_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
dns_client_mock.lookupService.return_value = result_deferred
- cache = {}
+ cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@defer.inlineCallbacks
- def do_lookup():
-
+ def do_lookup() -> Generator["Deferred[object]", object, List[Server]]:
with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)
- result = yield defer.ensureDeferred(resolve_d)
+ result: List[Server]
+ result = yield defer.ensureDeferred(resolve_d) # type: ignore[assignment]
# should have restored our context
self.assertIs(current_context(), ctx)
@@ -70,7 +70,9 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertEqual(servers[0].host, host_name)
@defer.inlineCallbacks
- def test_from_cache_expired_and_dns_fail(self):
+ def test_from_cache_expired_and_dns_fail(
+ self,
+ ) -> Generator["Deferred[object]", object, None]:
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
@@ -81,10 +83,13 @@ class SrvResolverTestCase(unittest.TestCase):
entry.priority = 0
entry.weight = 0
- cache = {service_name: [entry]}
+ cache = {service_name: [cast(Server, entry)]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
+ servers: List[Server]
+ servers = yield defer.ensureDeferred(
+ resolver.resolve_service(service_name)
+ ) # type: ignore[assignment]
dns_client_mock.lookupService.assert_called_once_with(service_name)
@@ -92,7 +97,7 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertEqual(servers, cache[service_name])
@defer.inlineCallbacks
- def test_from_cache(self):
+ def test_from_cache(self) -> Generator["Deferred[object]", object, None]:
clock = MockClock()
dns_client_mock = Mock(spec_set=["lookupService"])
@@ -105,12 +110,15 @@ class SrvResolverTestCase(unittest.TestCase):
entry.priority = 0
entry.weight = 0
- cache = {service_name: [entry]}
+ cache = {service_name: [cast(Server, entry)]}
resolver = SrvResolver(
dns_client=dns_client_mock, cache=cache, get_time=clock.time
)
- servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
+ servers: List[Server]
+ servers = yield defer.ensureDeferred(
+ resolver.resolve_service(service_name)
+ ) # type: ignore[assignment]
self.assertFalse(dns_client_mock.lookupService.called)
@@ -118,45 +126,48 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertEqual(servers, cache[service_name])
@defer.inlineCallbacks
- def test_empty_cache(self):
+ def test_empty_cache(self) -> Generator["Deferred[object]", object, None]:
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = b"test_service.example.com"
- cache = {}
+ cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
yield defer.ensureDeferred(resolver.resolve_service(service_name))
@defer.inlineCallbacks
- def test_name_error(self):
+ def test_name_error(self) -> Generator["Deferred[object]", object, None]:
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
service_name = b"test_service.example.com"
- cache = {}
+ cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
+ servers: List[Server]
+ servers = yield defer.ensureDeferred(
+ resolver.resolve_service(service_name)
+ ) # type: ignore[assignment]
self.assertEqual(len(servers), 0)
self.assertEqual(len(cache), 0)
- def test_disabled_service(self):
+ def test_disabled_service(self) -> None:
"""
test the behaviour when there is a single record which is ".".
"""
service_name = b"test_service.example.com"
- lookup_deferred = Deferred()
+ lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
- cache = {}
+ cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
# Old versions of Twisted don't have an ensureDeferred in failureResultOf.
@@ -173,16 +184,16 @@ class SrvResolverTestCase(unittest.TestCase):
self.failureResultOf(resolve_d, ConnectError)
- def test_non_srv_answer(self):
+ def test_non_srv_answer(self) -> None:
"""
test the behaviour when the dns server gives us a spurious non-SRV response
"""
service_name = b"test_service.example.com"
- lookup_deferred = Deferred()
+ lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
- cache = {}
+ cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
# Old versions of Twisted don't have an ensureDeferred in successResultOf.
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 5071f83574..36472e57a8 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -556,6 +556,6 @@ def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str:
return method_name
-def _hash_stack(stack: List[inspect.FrameInfo]):
+def _hash_stack(stack: List[inspect.FrameInfo]) -> Tuple[str, ...]:
"""Turns a stack into a hashable value that can be put into a set."""
return tuple(_format_stack_frame(frame) for frame in stack)
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 391196425c..ec6aacf235 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -11,28 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+from twisted.web.server import Request
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from tests.server import FakeSite, make_request
from tests.unittest import HomeserverTestCase
class _AsyncTestCustomEndpoint:
- def __init__(self, config, module_api):
+ def __init__(self, config: JsonDict, module_api: Any) -> None:
pass
- async def handle_request(self, request):
+ async def handle_request(self, request: Request) -> None:
+ assert isinstance(request, SynapseRequest)
respond_with_json(request, 200, {"some_key": "some_value_async"})
class _SyncTestCustomEndpoint:
- def __init__(self, config, module_api):
+ def __init__(self, config: JsonDict, module_api: Any) -> None:
pass
- async def handle_request(self, request):
+ async def handle_request(self, request: Request) -> None:
+ assert isinstance(request, SynapseRequest)
respond_with_json(request, 200, {"some_key": "some_value_sync"})
@@ -41,7 +47,7 @@ class AdditionalResourceTests(HomeserverTestCase):
and async handlers.
"""
- def test_async(self):
+ def test_async(self) -> None:
handler = _AsyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
@@ -52,7 +58,7 @@ class AdditionalResourceTests(HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
- def test_sync(self):
+ def test_sync(self) -> None:
handler = _SyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 7e2f2a01cc..57b6a84e23 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -13,10 +13,12 @@
# limitations under the License.
from io import BytesIO
+from typing import Tuple, Union
from unittest.mock import Mock
from netaddr import IPSet
+from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
@@ -28,6 +30,7 @@ from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
+ _DiscardBodyWithMaxSizeProtocol,
read_body_with_max_size,
)
@@ -36,7 +39,9 @@ from tests.unittest import TestCase
class ReadBodyWithMaxSizeTests(TestCase):
- def _build_response(self, length=UNKNOWN_LENGTH):
+ def _build_response(
+ self, length: Union[int, str] = UNKNOWN_LENGTH
+ ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=length)
result = BytesIO()
@@ -48,23 +53,27 @@ class ReadBodyWithMaxSizeTests(TestCase):
return result, deferred, protocol
- def _assert_error(self, deferred, protocol):
+ def _assert_error(
+ self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol
+ ) -> None:
"""Ensure that the expected error is received."""
- self.assertIsInstance(deferred.result, Failure)
+ assert isinstance(deferred.result, Failure)
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
- protocol.transport.abortConnection.assert_called_once()
+ assert protocol.transport is not None
+ # type-ignore: presumably abortConnection has been replaced with a Mock.
+ protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
- def _cleanup_error(self, deferred):
+ def _cleanup_error(self, deferred: "Deferred[int]") -> None:
"""Ensure that the error in the Deferred is handled gracefully."""
called = [False]
- def errback(f):
+ def errback(f: Failure) -> None:
called[0] = True
deferred.addErrback(errback)
self.assertTrue(called[0])
- def test_no_error(self):
+ def test_no_error(self) -> None:
"""A response that is NOT too large."""
result, deferred, protocol = self._build_response()
@@ -76,7 +85,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self.assertEqual(result.getvalue(), b"12345")
self.assertEqual(deferred.result, 5)
- def test_too_large(self):
+ def test_too_large(self) -> None:
"""A response which is too large raises an exception."""
result, deferred, protocol = self._build_response()
@@ -87,7 +96,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
- def test_multiple_packets(self):
+ def test_multiple_packets(self) -> None:
"""Data should be accumulated through mutliple packets."""
result, deferred, protocol = self._build_response()
@@ -100,7 +109,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self.assertEqual(result.getvalue(), b"1234")
self.assertEqual(deferred.result, 4)
- def test_additional_data(self):
+ def test_additional_data(self) -> None:
"""A connection can receive data after being closed."""
result, deferred, protocol = self._build_response()
@@ -115,7 +124,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
- def test_content_length(self):
+ def test_content_length(self) -> None:
"""The body shouldn't be read (at all) if the Content-Length header is too large."""
result, deferred, protocol = self._build_response(length=10)
@@ -132,7 +141,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
class BlacklistingAgentTest(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.reactor, self.clock = get_clock()
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
@@ -140,7 +149,7 @@ class BlacklistingAgentTest(TestCase):
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
# Configure the reactor's DNS resolver.
- for (domain, ip) in (
+ for domain, ip in (
(self.safe_domain, self.safe_ip),
(self.unsafe_domain, self.unsafe_ip),
(self.allowed_domain, self.allowed_ip),
@@ -151,7 +160,7 @@ class BlacklistingAgentTest(TestCase):
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
self.ip_blacklist = IPSet(["5.0.0.0/8"])
- def test_reactor(self):
+ def test_reactor(self) -> None:
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
BlacklistingReactorWrapper(
@@ -197,12 +206,12 @@ class BlacklistingAgentTest(TestCase):
response = self.successResultOf(d)
self.assertEqual(response.code, 200)
- def test_agent(self):
+ def test_agent(self) -> None:
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
Agent(self.reactor),
- ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
+ ip_whitelist=self.ip_whitelist,
)
# The unsafe IPs should be rejected.
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index a801f002a0..8c18e56881 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -17,7 +17,7 @@ from tests import unittest
class ServerNameTestCase(unittest.TestCase):
- def test_parse_server_name(self):
+ def test_parse_server_name(self) -> None:
test_data = {
"localhost": ("localhost", None),
"my-example.com:1234": ("my-example.com", 1234),
@@ -32,7 +32,7 @@ class ServerNameTestCase(unittest.TestCase):
for i, o in test_data.items():
self.assertEqual(parse_server_name(i), o)
- def test_validate_bad_server_names(self):
+ def test_validate_bad_server_names(self) -> None:
test_data = [
"", # empty
"localhost:http", # non-numeric port
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index 9b64890bdb..9373523dea 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -11,16 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Generator
from unittest.mock import Mock
from netaddr import IPSet
from parameterized import parameterized
from twisted.internet import defer
-from twisted.internet.defer import TimeoutError
+from twisted.internet.defer import Deferred, TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
from twisted.web.client import ResponseNeverReceived
from twisted.web.http import HTTPChannel
@@ -30,34 +30,43 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import (
+ SENTINEL_CONTEXT,
+ LoggingContext,
+ LoggingContextOrSentinel,
+ current_context,
+)
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
-def check_logcontext(context):
+def check_logcontext(context: LoggingContextOrSentinel) -> None:
current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
class FederationClientTests(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
return hs
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.cl = MatrixFederationHttpClient(self.hs, None)
self.reactor.lookups["testserv"] = "1.2.3.4"
- def test_client_get(self):
+ def test_client_get(self) -> None:
"""
happy-path test of a GET request
"""
@defer.inlineCallbacks
- def do_request():
+ def do_request() -> Generator["Deferred[object]", object, object]:
with LoggingContext("one") as context:
fetch_d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar")
@@ -119,7 +128,7 @@ class FederationClientTests(HomeserverTestCase):
# check the response is as expected
self.assertEqual(res, {"a": 1})
- def test_dns_error(self):
+ def test_dns_error(self) -> None:
"""
If the DNS lookup returns an error, it will bubble up.
"""
@@ -132,7 +141,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
- def test_client_connection_refused(self):
+ def test_client_connection_refused(self) -> None:
d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
)
@@ -156,7 +165,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIs(f.value.inner_exception, e)
- def test_client_never_connect(self):
+ def test_client_never_connect(self) -> None:
"""
If the HTTP request is not connected and is timed out, it'll give a
ConnectingCancelledError or TimeoutError.
@@ -188,7 +197,7 @@ class FederationClientTests(HomeserverTestCase):
f.value.inner_exception, (ConnectingCancelledError, TimeoutError)
)
- def test_client_connect_no_response(self):
+ def test_client_connect_no_response(self) -> None:
"""
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
@@ -222,7 +231,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
- def test_client_ip_range_blacklist(self):
+ def test_client_ip_range_blacklist(self) -> None:
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
# Set up the ip_range blacklist
@@ -292,7 +301,7 @@ class FederationClientTests(HomeserverTestCase):
f = self.failureResultOf(d, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)
- def test_client_gets_headers(self):
+ def test_client_gets_headers(self) -> None:
"""
Once the client gets the headers, _request returns successfully.
"""
@@ -319,7 +328,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertEqual(r.code, 200)
@parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
- def test_timeout_reading_body(self, method_name: str):
+ def test_timeout_reading_body(self, method_name: str) -> None:
"""
If the HTTP request is connected, but gets no response before being
timed out, it'll give a RequestSendFailed with can_retry.
@@ -351,7 +360,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertTrue(f.value.can_retry)
self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
- def test_client_requires_trailing_slashes(self):
+ def test_client_requires_trailing_slashes(self) -> None:
"""
If a connection is made to a client but the client rejects it due to
requiring a trailing slash. We need to retry the request with a
@@ -405,7 +414,7 @@ class FederationClientTests(HomeserverTestCase):
r = self.successResultOf(d)
self.assertEqual(r, {})
- def test_client_does_not_retry_on_400_plus(self):
+ def test_client_does_not_retry_on_400_plus(self) -> None:
"""
Another test for trailing slashes but now test that we don't retry on
trailing slashes on a non-400/M_UNRECOGNIZED response.
@@ -450,7 +459,7 @@ class FederationClientTests(HomeserverTestCase):
# We should get a 404 failure response
self.failureResultOf(d)
- def test_client_sends_body(self):
+ def test_client_sends_body(self) -> None:
defer.ensureDeferred( # type: ignore[unused-awaitable]
self.cl.post_json(
"testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
@@ -474,7 +483,7 @@ class FederationClientTests(HomeserverTestCase):
content = request.content.read()
self.assertEqual(content, b'{"a":"b"}')
- def test_closes_connection(self):
+ def test_closes_connection(self) -> None:
"""Check that the client closes unused HTTP connections"""
d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
@@ -514,7 +523,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertTrue(conn.disconnecting)
@parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
- def test_json_error(self, return_value):
+ def test_json_error(self, return_value: bytes) -> None:
"""
Test what happens if invalid JSON is returned from the remote endpoint.
"""
@@ -560,7 +569,7 @@ class FederationClientTests(HomeserverTestCase):
f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, RequestSendFailed)
- def test_too_big(self):
+ def test_too_big(self) -> None:
"""
Test what happens if a huge response is returned from the remote endpoint.
"""
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 2db77c6a73..cc175052ac 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -14,7 +14,7 @@
import base64
import logging
import os
-from typing import Iterable, Optional
+from typing import List, Optional
from unittest.mock import patch
import treq
@@ -22,9 +22,13 @@ from netaddr import IPSet
from parameterized import parameterized
from twisted.internet import interfaces # noqa: F401
-from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint
+from twisted.internet.endpoints import (
+ HostnameEndpoint,
+ _WrapperEndpoint,
+ _WrappingProtocol,
+)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
@@ -32,9 +36,14 @@ from synapse.http.client import BlacklistingReactorWrapper
from synapse.http.connectproxyclient import ProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy
-from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.http import (
+ TestServerTLSConnectionFactory,
+ dummy_address,
+ get_test_https_policy,
+)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
+from tests.utils import checked_cast
logger = logging.getLogger(__name__)
@@ -183,7 +192,7 @@ class ProxyParserTests(TestCase):
expected_hostname: bytes,
expected_port: int,
expected_credentials: Optional[bytes],
- ):
+ ) -> None:
"""
Tests that a given proxy URL will be broken into the components.
Args:
@@ -209,7 +218,7 @@ class ProxyParserTests(TestCase):
class MatrixFederationAgentTests(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
def _make_connection(
@@ -218,7 +227,7 @@ class MatrixFederationAgentTests(TestCase):
server_factory: IProtocolFactory,
ssl: bool = False,
expected_sni: Optional[bytes] = None,
- tls_sanlist: Optional[Iterable[bytes]] = None,
+ tls_sanlist: Optional[List[bytes]] = None,
) -> IProtocol:
"""Builds a test server, and completes the outgoing client connection
@@ -244,7 +253,8 @@ class MatrixFederationAgentTests(TestCase):
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
- server_protocol = server_factory.buildProtocol(None)
+ server_protocol = server_factory.buildProtocol(dummy_address)
+ assert server_protocol is not None
# now, tell the client protocol factory to build the client protocol,
# and wire the output of said protocol up to the server via
@@ -252,7 +262,8 @@ class MatrixFederationAgentTests(TestCase):
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
- client_protocol = client_factory.buildProtocol(None)
+ client_protocol = client_factory.buildProtocol(dummy_address)
+ assert client_protocol is not None
client_protocol.makeConnection(
FakeTransport(server_protocol, self.reactor, client_protocol)
)
@@ -263,6 +274,7 @@ class MatrixFederationAgentTests(TestCase):
)
if ssl:
+ assert isinstance(server_protocol, TLSMemoryBIOProtocol)
http_protocol = server_protocol.wrappedProtocol
tls_connection = server_protocol._tlsConnection
else:
@@ -288,7 +300,7 @@ class MatrixFederationAgentTests(TestCase):
scheme: bytes,
hostname: bytes,
path: bytes,
- ):
+ ) -> None:
"""Runs a test case for a direct connection not going through a proxy.
Args:
@@ -319,6 +331,7 @@ class MatrixFederationAgentTests(TestCase):
ssl=is_https,
expected_sni=hostname if is_https else None,
)
+ assert isinstance(http_server, HTTPChannel)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@@ -339,34 +352,34 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
- def test_http_request(self):
+ def test_http_request(self) -> None:
agent = ProxyAgent(self.reactor)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- def test_https_request(self):
+ def test_https_request(self) -> None:
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
- def test_http_request_use_proxy_empty_environment(self):
+ def test_http_request_use_proxy_empty_environment(self) -> None:
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
- def test_http_request_via_uppercase_no_proxy(self):
+ def test_http_request_via_uppercase_no_proxy(self) -> None:
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(
os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
)
- def test_http_request_via_no_proxy(self):
+ def test_http_request_via_no_proxy(self) -> None:
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(
os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
)
- def test_https_request_via_no_proxy(self):
+ def test_https_request_via_no_proxy(self) -> None:
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
@@ -375,12 +388,12 @@ class MatrixFederationAgentTests(TestCase):
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
- def test_http_request_via_no_proxy_star(self):
+ def test_http_request_via_no_proxy_star(self) -> None:
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
- def test_https_request_via_no_proxy_star(self):
+ def test_https_request_via_no_proxy_star(self) -> None:
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
@@ -389,7 +402,7 @@ class MatrixFederationAgentTests(TestCase):
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
- def test_http_request_via_proxy(self):
+ def test_http_request_via_proxy(self) -> None:
"""
Tests that requests can be made through a proxy.
"""
@@ -401,7 +414,7 @@ class MatrixFederationAgentTests(TestCase):
os.environ,
{"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"},
)
- def test_http_request_via_proxy_with_auth(self):
+ def test_http_request_via_proxy_with_auth(self) -> None:
"""
Tests that authenticated requests can be made through a proxy.
"""
@@ -412,7 +425,7 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(
os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
)
- def test_http_request_via_https_proxy(self):
+ def test_http_request_via_https_proxy(self) -> None:
self._do_http_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=None
)
@@ -424,13 +437,13 @@ class MatrixFederationAgentTests(TestCase):
"no_proxy": "unused.com",
},
)
- def test_http_request_via_https_proxy_with_auth(self):
+ def test_http_request_via_https_proxy_with_auth(self) -> None:
self._do_http_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
)
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
- def test_https_request_via_proxy(self):
+ def test_https_request_via_proxy(self) -> None:
"""Tests that TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=None
@@ -440,7 +453,7 @@ class MatrixFederationAgentTests(TestCase):
os.environ,
{"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
)
- def test_https_request_via_proxy_with_auth(self):
+ def test_https_request_via_proxy_with_auth(self) -> None:
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
@@ -449,7 +462,7 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
)
- def test_https_request_via_https_proxy(self):
+ def test_https_request_via_https_proxy(self) -> None:
"""Tests that TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=None
@@ -459,7 +472,7 @@ class MatrixFederationAgentTests(TestCase):
os.environ,
{"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
)
- def test_https_request_via_https_proxy_with_auth(self):
+ def test_https_request_via_https_proxy_with_auth(self) -> None:
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
@@ -469,7 +482,7 @@ class MatrixFederationAgentTests(TestCase):
self,
expect_proxy_ssl: bool = False,
expected_auth_credentials: Optional[bytes] = None,
- ):
+ ) -> None:
"""Send a http request via an agent and check that it is correctly received at
the proxy. The proxy can use either http or https.
Args:
@@ -501,6 +514,7 @@ class MatrixFederationAgentTests(TestCase):
tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
expected_sni=b"proxy.com" if expect_proxy_ssl else None,
)
+ assert isinstance(http_server, HTTPChannel)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@@ -542,7 +556,7 @@ class MatrixFederationAgentTests(TestCase):
self,
expect_proxy_ssl: bool = False,
expected_auth_credentials: Optional[bytes] = None,
- ):
+ ) -> None:
"""Send a https request via an agent and check that it is correctly received at
the proxy and client. The proxy can use either http or https.
Args:
@@ -606,10 +620,11 @@ class MatrixFederationAgentTests(TestCase):
# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
- ).buildProtocol(None)
+ ).buildProtocol(dummy_address)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
+ assert proxy_server_transport is not None
server_ssl_protocol.makeConnection(proxy_server_transport)
# ... and replace the protocol on the proxy's transport with the
@@ -629,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -644,6 +660,7 @@ class MatrixFederationAgentTests(TestCase):
# now there should be a pending request
http_server = server_ssl_protocol.wrappedProtocol
+ assert isinstance(http_server, HTTPChannel)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
@@ -667,7 +684,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
- def test_http_request_via_proxy_with_blacklist(self):
+ def test_http_request_via_proxy_with_blacklist(self) -> None:
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
BlacklistingReactorWrapper(
@@ -691,6 +708,7 @@ class MatrixFederationAgentTests(TestCase):
http_server = self._make_connection(
client_factory, _get_test_protocol_factory()
)
+ assert isinstance(http_server, HTTPChannel)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@@ -712,7 +730,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
- def test_https_request_via_uppercase_proxy_with_blacklist(self):
+ def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None:
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
BlacklistingReactorWrapper(
@@ -737,11 +755,17 @@ class MatrixFederationAgentTests(TestCase):
proxy_server = self._make_connection(
client_factory, _get_test_protocol_factory()
)
+ assert isinstance(proxy_server, HTTPChannel)
# fish the transports back out so that we can do the old switcheroo
- s2c_transport = proxy_server.transport
- client_protocol = s2c_transport.other
- c2s_transport = client_protocol.transport
+ # To help mypy out with the various Protocols and wrappers and mocks, we do
+ # some explicit casting. Without the casts, we hit the bug I reported at
+ # https://github.com/Shoobx/mypy-zope/issues/91 .
+ # We also double-checked these casts at runtime (test-time) because I found it
+ # quite confusing to deduce these types in the first place!
+ s2c_transport = checked_cast(FakeTransport, proxy_server.transport)
+ client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@@ -762,8 +786,10 @@ class MatrixFederationAgentTests(TestCase):
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
- ssl_protocol = ssl_factory.buildProtocol(None)
+ ssl_protocol = ssl_factory.buildProtocol(dummy_address)
+ assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
http_server = ssl_protocol.wrappedProtocol
+ assert isinstance(http_server, HTTPChannel)
ssl_protocol.makeConnection(
FakeTransport(client_protocol, self.reactor, ssl_protocol)
@@ -797,39 +823,35 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
- def test_proxy_with_no_scheme(self):
+ def test_proxy_with_no_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+ proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
- def test_proxy_with_unsupported_scheme(self):
+ def test_proxy_with_unsupported_scheme(self) -> None:
with self.assertRaises(ValueError):
ProxyAgent(self.reactor, use_proxy=True)
@patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
- def test_proxy_with_http_scheme(self):
+ def test_proxy_with_http_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+ proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
- def test_proxy_with_https_scheme(self):
+ def test_proxy_with_https_scheme(self) -> None:
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
- self.assertEqual(
- https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
- )
- self.assertEqual(
- https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888
- )
+ proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
def _wrap_server_factory_for_tls(
- factory: IProtocolFactory, sanlist: Iterable[bytes] = None
-) -> IProtocolFactory:
+ factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
+) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
@@ -865,6 +887,6 @@ def _get_test_protocol_factory() -> IProtocolFactory:
return server_factory
-def _log_request(request: str):
+def _log_request(request: str) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info(f"Completed request {request}")
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index 46166292fe..c8d215b6dc 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -14,7 +14,7 @@
import json
from http import HTTPStatus
from io import BytesIO
-from typing import Tuple
+from typing import Tuple, Union
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
@@ -33,7 +33,7 @@ from tests import unittest
from tests.http.server._base import test_disconnect
-def make_request(content):
+def make_request(content: Union[bytes, JsonDict]) -> Mock:
"""Make an object that acts enough like a request."""
request = Mock(spec=["method", "uri", "content"])
@@ -47,7 +47,7 @@ def make_request(content):
class TestServletUtils(unittest.TestCase):
- def test_parse_json_value(self):
+ def test_parse_json_value(self) -> None:
"""Basic tests for parse_json_value_from_request."""
# Test round-tripping.
obj = {"foo": 1}
@@ -78,7 +78,7 @@ class TestServletUtils(unittest.TestCase):
with self.assertRaises(SynapseError):
parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
- def test_parse_json_object(self):
+ def test_parse_json_object(self) -> None:
"""Basic tests for parse_json_object_from_request."""
# Test empty.
result = parse_json_object_from_request(
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
index c85a3665c1..010601da4b 100644
--- a/tests/http/test_simple_client.py
+++ b/tests/http/test_simple_client.py
@@ -17,22 +17,24 @@ from netaddr import IPSet
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
+from twisted.test.proto_helpers import MemoryReactor
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class SimpleHttpClientTests(HomeserverTestCase):
- def prepare(self, reactor, clock, hs: "HomeServer"):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: "HomeServer") -> None:
# Add a DNS entry for a test server
self.reactor.lookups["testserv"] = "1.2.3.4"
self.cl = hs.get_simple_http_client()
- def test_dns_error(self):
+ def test_dns_error(self) -> None:
"""
If the DNS lookup returns an error, it will bubble up.
"""
@@ -42,7 +44,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
f = self.failureResultOf(d)
self.assertIsInstance(f.value, DNSLookupError)
- def test_client_connection_refused(self):
+ def test_client_connection_refused(self) -> None:
d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
self.pump()
@@ -63,7 +65,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.assertIs(f.value, e)
- def test_client_never_connect(self):
+ def test_client_never_connect(self) -> None:
"""
If the HTTP request is not connected and is timed out, it'll give a
ConnectingCancelledError or TimeoutError.
@@ -90,7 +92,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestTimedOutError)
- def test_client_connect_no_response(self):
+ def test_client_connect_no_response(self) -> None:
"""
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
@@ -121,7 +123,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestTimedOutError)
- def test_client_ip_range_blacklist(self):
+ def test_client_ip_range_blacklist(self) -> None:
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
# Add some DNS entries we'll blacklist
diff --git a/tests/http/test_site.py b/tests/http/test_site.py
index b2dbf76d33..9a78fede92 100644
--- a/tests/http/test_site.py
+++ b/tests/http/test_site.py
@@ -13,18 +13,20 @@
# limitations under the License.
from twisted.internet.address import IPv6Address
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
from synapse.app.homeserver import SynapseHomeServer
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class SynapseRequestTestCase(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
- def test_large_request(self):
+ def test_large_request(self) -> None:
"""overlarge HTTP requests should be rejected"""
self.hs.start_listening()
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index c08954d887..5191e31a8a 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -21,6 +21,7 @@ from synapse.logging import RemoteHandler
from tests.logging import LoggerCleanupMixin
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
+from tests.utils import checked_cast
def connect_logging_client(
@@ -56,8 +57,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
client, server = connect_logging_client(self.reactor, 0)
# Trigger data being sent
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# One log message, with a single trailing newline
logs = server.data.decode("utf8").splitlines()
@@ -89,8 +90,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# Only the 7 infos made it through, the debugs were elided
logs = server.data.splitlines()
@@ -123,8 +124,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# The 10 warnings made it through, the debugs and infos were elided
logs = server.data.splitlines()
@@ -148,8 +149,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# The first five and last five warnings made it through, the debugs and
# infos were elided
diff --git a/tests/rest/media/v1/__init__.py b/tests/media/__init__.py
index b1ee10cfcc..68910cbf5b 100644
--- a/tests/rest/media/v1/__init__.py
+++ b/tests/media/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/rest/media/v1/test_base.py b/tests/media/test_base.py
index c73179151a..66498c744d 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/media/test_base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.media._base import get_filename_from_headers
from tests import unittest
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/media/test_filepath.py
index 43e6f0f70a..95e3b83d5a 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/media/test_filepath.py
@@ -15,7 +15,7 @@ import inspect
import os
from typing import Iterable
-from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check
+from synapse.media.filepath import MediaFilePaths, _wrap_with_jail_check
from tests import unittest
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/media/test_html_preview.py
index 1062081a06..e7da75db3e 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/media/test_html_preview.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.rest.media.v1.preview_html import (
+from synapse.media.preview_html import (
_get_html_media_encodings,
decode_body,
parse_html_to_open_graph,
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/media/test_media_storage.py
index d18fc13c21..870047d0f2 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
-from typing import Any, BinaryIO, Dict, List, Optional, Union
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib import parse
@@ -32,16 +32,17 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
+from synapse.media._base import FileInfo
+from synapse.media.filepath import MediaFilePaths
+from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
+from synapse.media.storage_provider import FileStorageProviderBackend
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login
-from synapse.rest.media.v1._base import FileInfo
-from synapse.rest.media.v1.filepath import MediaFilePaths
-from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
-from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.server import HomeServer
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
from tests import unittest
@@ -51,7 +52,6 @@ from tests.utils import default_config
class MediaStorageTests(unittest.HomeserverTestCase):
-
needs_threadpool = True
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -201,36 +201,45 @@ class _TestImage:
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
-
+ test_image: ClassVar[_TestImage]
hijack_auth = True
user_id = "@test:user"
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
- self.fetches = []
+ self.fetches: List[
+ Tuple[
+ "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]",
+ str,
+ str,
+ Optional[QueryParams],
+ ]
+ ] = []
def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
- args: Optional[Dict[str, Union[str, List[str]]]] = None,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
- ) -> Deferred:
- """
- Returns tuple[int,dict,str,int] of file length, response headers,
- absolute URI, and response code.
- """
+ ignore_backoff: bool = False,
+ ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+ """A mock for MatrixFederationHttpClient.get_file."""
- def write_to(r):
+ def write_to(
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
return response
- d = Deferred()
- d.addCallback(write_to)
+ d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
- return make_deferred_yieldable(d)
+ # Note that this callback changes the value held by d.
+ d_after_callback = d.addCallback(write_to)
+ return make_deferred_yieldable(d_after_callback)
+ # Mock out the homeserver's MatrixFederationHttpClient
client = Mock()
client.get_file = get_file
@@ -244,7 +253,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config["max_image_pixels"] = 2000000
provider_config = {
- "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
@@ -257,7 +266,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
media_resource = hs.get_media_repository_resource()
self.download_resource = media_resource.children[b"download"]
self.thumbnail_resource = media_resource.children[b"thumbnail"]
@@ -461,6 +469,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
# Synapse should regenerate missing thumbnails.
origin, media_id = self.media_id.split("/")
info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
+ assert info is not None
file_id = info["filesystem_id"]
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
@@ -581,7 +590,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail1{self.test_image.extension}",
+ "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
},
{
"thumbnail_width": 32,
@@ -589,10 +598,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail2{self.test_image.extension}",
+ "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
},
],
- file_id=f"image{self.test_image.extension}",
+ file_id=f"image{self.test_image.extension.decode()}",
url_cache=None,
server_name=None,
)
@@ -637,6 +646,7 @@ class TestSpamCheckerLegacy:
self.config = config
self.api = api
+ @staticmethod
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
@@ -748,7 +758,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
- ) -> Union[Codes, Literal["NOT_SPAM"]]:
+ ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/media/test_oembed.py
index 3f7f1dbab9..c8bf8421da 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/media/test_oembed.py
@@ -18,7 +18,7 @@ from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
-from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
+from synapse.media.oembed import OEmbedProvider, OEmbedResult
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8f88c0117d..3a1929691e 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import NotFoundError
@@ -21,9 +23,12 @@ from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException
+from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, notifications, presence, profile, room
-from synapse.types import create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
+from synapse.util import Clock
from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
-class ModuleApiTestCase(HomeserverTestCase):
+class BaseModuleApiTestCase(HomeserverTestCase):
+ """Common properties of the two test case classes."""
+
+ module_api: ModuleApi
+
+ # These are all written by _test_sending_local_online_presence_to_local_user.
+ presence_receiver_id: str
+ presence_receiver_tok: str
+ presence_sender_id: str
+ presence_sender_tok: str
+
+
+class ModuleApiTestCase(BaseModuleApiTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -42,23 +59,23 @@ class ModuleApiTestCase(HomeserverTestCase):
notifications.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
- self.store = homeserver.get_datastores().main
- self.module_api = homeserver.get_module_api()
- self.event_creation_handler = homeserver.get_event_creation_handler()
- self.sync_handler = homeserver.get_sync_handler()
- self.auth_handler = homeserver.get_auth_handler()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.module_api = hs.get_module_api()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.sync_handler = hs.get_sync_handler()
+ self.auth_handler = hs.get_auth_handler()
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
- fed_transport_client = Mock(spec=["send_transaction"])
- fed_transport_client.send_transaction = simple_async_mock({})
+ self.fed_transport_client = Mock(spec=["send_transaction"])
+ self.fed_transport_client.send_transaction = simple_async_mock({})
return self.setup_test_homeserver(
- federation_transport_client=fed_transport_client,
+ federation_transport_client=self.fed_transport_client,
)
- def test_can_register_user(self):
+ def test_can_register_user(self) -> None:
"""Tests that an external module can register a user"""
# Register a new user
user_id, access_token = self.get_success(
@@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
- def test_can_register_admin_user(self):
+ def test_can_register_admin_user(self) -> None:
user_id = self.register_user(
"bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
)
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
- def test_can_set_admin(self):
+ def test_can_set_admin(self) -> None:
user_id = self.register_user(
"alice_wants_admin",
"1234",
@@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase):
self.get_success(self.module_api.set_user_admin(user_id, True))
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
- def test_can_set_displayname(self):
+ def test_can_set_displayname(self) -> None:
localpart = "alice_wants_a_new_displayname"
user_id = self.register_user(
localpart, "1234", displayname="Alice", admin=False
)
found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
-
+ assert found_userinfo is not None
self.get_success(
self.module_api.set_displayname(
found_userinfo.user_id, "Bob", deactivation=False
@@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(found_profile.display_name, "Bob")
- def test_get_userinfo_by_id(self):
+ def test_get_userinfo_by_id(self) -> None:
user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, False)
- def test_get_userinfo_by_id__no_user_found(self):
+ def test_get_userinfo_by_id__no_user_found(self) -> None:
found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
self.assertIsNone(found_user)
- def test_get_user_ip_and_agents(self):
+ def test_get_user_ip_and_agents(self) -> None:
user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
# Initially, we should have no ip/agent for our user.
@@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# we should only find the second ip, agent.
info = self.get_success(
self.module_api.get_user_ip_and_agents(
- user_id, (last_seen_1 + last_seen_2) / 2
+ user_id, (last_seen_1 + last_seen_2) // 2
)
)
self.assertEqual(len(info), 1)
@@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
- def test_get_user_ip_and_agents__no_user_found(self):
+ def test_get_user_ip_and_agents__no_user_found(self) -> None:
info = self.get_success(
self.module_api.get_user_ip_and_agents(
"@test_get_user_ip_and_agents_user_nonexistent:example.com"
@@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
- def test_sending_events_into_room(self):
+ def test_sending_events_into_room(self) -> None:
"""Tests that a module can send events into a room"""
# Mock out create_and_send_nonmember_event to check whether events are being sent
- self.event_creation_handler.create_and_send_nonmember_event = Mock(
+ self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment]
spec=[],
side_effect=self.event_creation_handler.create_and_send_nonmember_event,
)
@@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room_id = self.helper.create_room_as(user_id, tok=tok)
# Create and send a non-state event
- content = {"body": "I am a puppet", "msgtype": "m.text"}
+ content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"}
event_dict = {
"room_id": room_id,
"type": "m.room.message",
@@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase):
"sender": user_id,
"state_key": "",
}
- event: EventBase = self.get_success(
+ event = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
)
self.assertEqual(event.sender, user_id)
@@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.create_and_send_event_into_room(event_dict), Exception
)
- def test_public_rooms(self):
+ def test_public_rooms(self) -> None:
"""Tests that a room can be added and removed from the public rooms list,
as well as have its public rooms directory state queried.
"""
@@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertFalse(is_in_public_rooms)
- def test_send_local_online_presence_to(self):
+ def test_send_local_online_presence_to(self) -> None:
# Test sending local online presence to users from the main process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
# Enable federation sending on the main process.
@override_config({"federation_sender_instances": None})
- def test_send_local_online_presence_to_federation(self):
+ def test_send_local_online_presence_to_federation(self) -> None:
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates
self.presence_sender_id = self.register_user("presence_sender1", "monkey")
@@ -397,7 +417,7 @@ class ModuleApiTestCase(HomeserverTestCase):
#
# Thus we reset the mock, and try sending online local user
# presence again
- self.hs.get_federation_transport_client().send_transaction.reset_mock()
+ self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
@@ -409,9 +429,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that a presence update was sent as part of a federation transaction
found_update = False
- calls = (
- self.hs.get_federation_transport_client().send_transaction.call_args_list
- )
+ calls = self.fed_transport_client.send_transaction.call_args_list
for call in calls:
call_args = call[0]
federation_transaction: Transaction = call_args[0]
@@ -431,7 +449,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertTrue(found_update)
- def test_update_membership(self):
+ def test_update_membership(self) -> None:
"""Tests that the module API can update the membership of a user in a room."""
peter = self.register_user("peter", "hackme")
lesley = self.register_user("lesley", "hackme")
@@ -554,14 +572,14 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(res["displayname"], "simone")
self.assertIsNone(res["avatar_url"])
- def test_update_room_membership_remote_join(self):
+ def test_update_room_membership_remote_join(self) -> None:
"""Test that the module API can join a remote room."""
# Necessary to fake a remote join.
fake_stream_id = 1
mocked_remote_join = simple_async_mock(
return_value=("fake-event-id", fake_stream_id)
)
- self.hs.get_room_member_handler()._remote_join = mocked_remote_join
+ self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
fake_remote_host = f"{self.module_api.server_name}-remote"
# Given that the join is to be faked, we expect the relevant join event not to
@@ -582,7 +600,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that a remote join was attempted.
self.assertEqual(mocked_remote_join.call_count, 1)
- def test_get_room_state(self):
+ def test_get_room_state(self) -> None:
"""Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme")
tok = self.login("peter", "hackme")
@@ -677,7 +695,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.check_push_rule_actions(["foo"])
with self.assertRaises(InvalidRuleException):
- self.module_api.check_push_rule_actions({"foo": "bar"})
+ self.module_api.check_push_rule_actions([{"foo": "bar"}])
self.module_api.check_push_rule_actions(["notify"])
@@ -756,7 +774,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertIsNone(room_alias)
-class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
+class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
servlets = [
@@ -766,7 +784,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
presence.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
conf = super().default_config()
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
@@ -774,18 +792,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
- def prepare(self, reactor, clock, homeserver):
- self.module_api = homeserver.get_module_api()
- self.sync_handler = homeserver.get_sync_handler()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.module_api = hs.get_module_api()
+ self.sync_handler = hs.get_sync_handler()
- def test_send_local_online_presence_to_workers(self):
+ def test_send_local_online_presence_to_workers(self) -> None:
# Test sending local online presence to users from a worker process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
def _test_sending_local_online_presence_to_local_user(
- test_case: HomeserverTestCase, test_with_workers: bool = False
-):
+ test_case: BaseModuleApiTestCase, test_with_workers: bool = False
+) -> None:
"""Tests that send_local_presence_to_users sends local online presence to local users.
This simultaneously tests two different usecases:
@@ -852,6 +870,7 @@ def _test_sending_local_online_presence_to_local_user(
# Replicate the current sync presence token from the main process to the worker process.
# We need to do this so that the worker process knows the current presence stream ID to
# insert into the database when we call ModuleApi.send_local_online_presence_to.
+ assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
test_case.replicate()
# Syncing again should result in no presence updates
@@ -868,6 +887,7 @@ def _test_sending_local_online_presence_to_local_user(
# Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
if test_with_workers:
+ assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
module_api_to_use = worker_hs.get_module_api()
else:
module_api_to_use = test_case.module_api
@@ -875,12 +895,11 @@ def _test_sending_local_online_presence_to_local_user(
# Trigger sending local online presence. We expect this information
# to be saved to the database where all processes can access it.
# Note that we're syncing via the master.
- d = module_api_to_use.send_local_online_presence_to(
- [
- test_case.presence_receiver_id,
- ]
+ d = defer.ensureDeferred(
+ module_api_to_use.send_local_online_presence_to(
+ [test_case.presence_receiver_id],
+ )
)
- d = defer.ensureDeferred(d)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the
@@ -897,7 +916,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update: UserPresenceState = presence_updates[0]
+ presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -908,7 +927,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update: UserPresenceState = presence_updates[0]
+ presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -936,12 +955,13 @@ def _test_sending_local_online_presence_to_local_user(
test_case.assertEqual(len(presence_updates), 1)
# Now trigger sending local online presence.
- d = module_api_to_use.send_local_online_presence_to(
- [
- test_case.presence_receiver_id,
- ]
+ d = defer.ensureDeferred(
+ module_api_to_use.send_local_online_presence_to(
+ [
+ test_case.presence_receiver_id,
+ ]
+ )
)
- d = defer.ensureDeferred(d)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 7567756135..46df0102f7 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -33,7 +33,6 @@ from tests.unittest import HomeserverTestCase, override_config
class TestBulkPushRuleEvaluator(HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -131,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Create a new message event, and try to evaluate it under the dodgy
# power level event.
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -146,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
prev_event_ids=[pl_event_id],
)
)
+ context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# should not raise
@@ -171,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
"""Ensure that push rules are not calculated when disabled in the config"""
# Create a new message event which should cause a notification.
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -185,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Mock the method which calculates push rules -- we do this instead of
@@ -201,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
) -> bool:
"""Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification.
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -212,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
-
+ context = self.get_success(unpersisted_context.persist(event))
# Execute the push rule machinery.
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
@@ -377,7 +378,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Create & persist an event to use as the parent of the relation.
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -391,6 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self.event_creation_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index ab8bb417e7..4ea5472eb4 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
+from synapse.push.emailpusher import EmailPusher
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -38,7 +39,6 @@ class _User:
class EmailPusherTests(HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -47,7 +47,6 @@ class EmailPusherTests(HomeserverTestCase):
hijack_auth = False
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["email"] = {
"enable_notifs": True,
@@ -105,6 +104,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
)
+ assert user_tuple is not None
self.token_id = user_tuple.token_id
# We need to add email to account before we can create a pusher.
@@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
)
)
- self.pusher = self.get_success(
+ pusher = self.get_success(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
@@ -127,6 +127,8 @@ class EmailPusherTests(HomeserverTestCase):
data={},
)
)
+ assert isinstance(pusher, EmailPusher)
+ self.pusher = pusher
self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main
@@ -367,18 +369,19 @@ class EmailPusherTests(HomeserverTestCase):
# disassociate the user's email address
self.get_success(
- self.auth_handler.delete_threepid(
- user_id=self.user_id,
- medium="email",
- address="a@example.com",
+ self.auth_handler.delete_local_threepid(
+ user_id=self.user_id, medium="email", address="a@example.com"
)
)
# check that the pusher for that email address has been deleted
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def test_remove_unlinked_pushers_background_job(self) -> None:
@@ -413,10 +416,13 @@ class EmailPusherTests(HomeserverTestCase):
self.wait_for_background_updates()
# Check that all pushers with unlinked addresses were deleted
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def _check_for_mail(self) -> Tuple[Sequence, Dict]:
@@ -428,10 +434,13 @@ class EmailPusherTests(HomeserverTestCase):
that notification.
"""
# Get the stream ordering before it gets sent
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -439,10 +448,13 @@ class EmailPusherTests(HomeserverTestCase):
self.pump(10)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@@ -458,10 +470,13 @@ class EmailPusherTests(HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 1)
# The stream ordering has increased
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 23447cc310..c280ddcdf6 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer
-from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict
from synapse.util import Clock
@@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
- def test_data(data: Optional[JsonDict]) -> None:
+ def test_data(data: Any) -> None:
self.get_failure(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
@@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.helper.send(room, body="There!", tok=other_access_token)
# Get the stream ordering before it gets sent
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# The stream ordering has increased
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# The stream ordering has increased, again
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
@@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
device_id = user_tuple.device_id
@@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase):
)
# Look up the user info for the access token so we can compare the device ID.
- lookup_result: TokenLookupResult = self.get_success(
+ lookup_result = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert lookup_result is not None
# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 7c430c4ecb..52c4aafea6 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Set, Union, cast
+from typing import Any, Dict, List, Optional, Union, cast
import frozendict
@@ -22,7 +22,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.room_versions import RoomVersions
from synapse.appservice import ApplicationService
-from synapse.events import FrozenEvent
+from synapse.events import FrozenEvent, make_event_from_dict
from synapse.push.bulk_push_rule_evaluator import _flatten_dict
from synapse.push.httppusher import tweaks_for_actions
from synapse.rest import admin
@@ -32,6 +32,7 @@ from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.synapse_rust.push import PushRuleEvaluator
from synapse.types import JsonDict, JsonMapping, UserID
from synapse.util import Clock
+from synapse.util.frozenutils import freeze
from tests import unittest
from tests.test_utils.event_injection import create_event, inject_member_event
@@ -48,17 +49,93 @@ class FlattenDictTestCase(unittest.TestCase):
input = {"foo": {"bar": "abc"}}
self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input))
+ # If a field has a dot in it, escape it.
+ input = {"m.foo": {"b\\ar": "abc"}}
+ self.assertEqual({"m\\.foo.b\\\\ar": "abc"}, _flatten_dict(input))
+
def test_non_string(self) -> None:
- """Non-string items are dropped."""
+ """String, booleans, ints, nulls and list of those should be kept while other items are dropped."""
input: Dict[str, Any] = {
"woo": "woo",
"foo": True,
"bar": 1,
"baz": None,
- "fuzz": [],
+ "fuzz": ["woo", True, 1, None, [], {}],
"boo": {},
}
- self.assertEqual({"woo": "woo"}, _flatten_dict(input))
+ self.assertEqual(
+ {
+ "woo": "woo",
+ "foo": True,
+ "bar": 1,
+ "baz": None,
+ "fuzz": ["woo", True, 1, None],
+ },
+ _flatten_dict(input),
+ )
+
+ def test_event(self) -> None:
+ """Events can also be flattened."""
+ event = make_event_from_dict(
+ {
+ "room_id": "!test:test",
+ "type": "m.room.message",
+ "sender": "@alice:test",
+ "content": {
+ "msgtype": "m.text",
+ "body": "Hello world!",
+ "format": "org.matrix.custom.html",
+ "formatted_body": "<h1>Hello world!</h1>",
+ },
+ },
+ room_version=RoomVersions.V8,
+ )
+ expected = {
+ "content.msgtype": "m.text",
+ "content.body": "Hello world!",
+ "content.format": "org.matrix.custom.html",
+ "content.formatted_body": "<h1>Hello world!</h1>",
+ "room_id": "!test:test",
+ "sender": "@alice:test",
+ "type": "m.room.message",
+ }
+ self.assertEqual(expected, _flatten_dict(event))
+
+ def test_extensible_events(self) -> None:
+ """Extensible events has compatibility behaviour."""
+ event_dict = {
+ "room_id": "!test:test",
+ "type": "m.room.message",
+ "sender": "@alice:test",
+ "content": {
+ "org.matrix.msc1767.markup": [
+ {"mimetype": "text/plain", "body": "Hello world!"},
+ {"mimetype": "text/html", "body": "<h1>Hello world!</h1>"},
+ ]
+ },
+ }
+
+ # For a current room version, there's no special behavior.
+ event = make_event_from_dict(event_dict, room_version=RoomVersions.V8)
+ expected = {
+ "room_id": "!test:test",
+ "sender": "@alice:test",
+ "type": "m.room.message",
+ "content.org\\.matrix\\.msc1767\\.markup": [],
+ }
+ self.assertEqual(expected, _flatten_dict(event))
+
+ # For a room version with extensible events, they parse out the text/plain
+ # to a content.body property.
+ event = make_event_from_dict(event_dict, room_version=RoomVersions.MSC1767v10)
+ expected = {
+ "content.body": "hello world!",
+ "room_id": "!test:test",
+ "sender": "@alice:test",
+ "type": "m.room.message",
+ "content.org\\.matrix\\.msc1767\\.markup": [],
+ }
+ self.assertEqual(expected, _flatten_dict(event))
class PushRuleEvaluatorTestCase(unittest.TestCase):
@@ -66,9 +143,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
self,
content: JsonMapping,
*,
- has_mentions: bool = False,
- user_mentions: Optional[Set[str]] = None,
- room_mention: bool = False,
related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator:
event = FrozenEvent(
@@ -87,9 +161,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluator(
_flatten_dict(event),
- has_mentions,
- user_mentions or set(),
- room_mention,
+ False,
room_member_count,
sender_power_level,
cast(Dict[str, int], power_levels.get("notifications", {})),
@@ -123,53 +195,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
# A display name with spaces should work fine.
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
- def test_user_mentions(self) -> None:
- """Check for user mentions."""
- condition = {"kind": "org.matrix.msc3952.is_user_mention"}
-
- # No mentions shouldn't match.
- evaluator = self._get_evaluator({}, has_mentions=True)
- self.assertFalse(evaluator.matches(condition, "@user:test", None))
-
- # An empty set shouldn't match
- evaluator = self._get_evaluator({}, has_mentions=True, user_mentions=set())
- self.assertFalse(evaluator.matches(condition, "@user:test", None))
-
- # The Matrix ID appearing anywhere in the mentions list should match
- evaluator = self._get_evaluator(
- {}, has_mentions=True, user_mentions={"@user:test"}
- )
- self.assertTrue(evaluator.matches(condition, "@user:test", None))
-
- evaluator = self._get_evaluator(
- {}, has_mentions=True, user_mentions={"@another:test", "@user:test"}
- )
- self.assertTrue(evaluator.matches(condition, "@user:test", None))
-
- # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
- # since the BulkPushRuleEvaluator is what handles data sanitisation.
-
- def test_room_mentions(self) -> None:
- """Check for room mentions."""
- condition = {"kind": "org.matrix.msc3952.is_room_mention"}
-
- # No room mention shouldn't match.
- evaluator = self._get_evaluator({}, has_mentions=True)
- self.assertFalse(evaluator.matches(condition, None, None))
-
- # Room mention should match.
- evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True)
- self.assertTrue(evaluator.matches(condition, None, None))
-
- # A room mention and user mention is valid.
- evaluator = self._get_evaluator(
- {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True
- )
- self.assertTrue(evaluator.matches(condition, None, None))
-
- # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
- # since the BulkPushRuleEvaluator is what handles data sanitisation.
-
def _assert_matches(
self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
) -> None:
@@ -341,6 +366,193 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline",
)
+ def test_event_match_pattern(self) -> None:
+ """Check that event_match conditions do not use a "pattern_type" from user data."""
+
+ # The pattern_type should not be deserialized into anything valid.
+ condition = {
+ "kind": "event_match",
+ "key": "content.value",
+ "pattern_type": "user_id",
+ }
+ self._assert_not_matches(
+ condition,
+ {"value": "@user:test"},
+ "should not be possible to pass a pattern_type in",
+ )
+
+ # This is an internal-only condition which shouldn't get deserialized.
+ condition = {
+ "kind": "event_match_type",
+ "key": "content.value",
+ "pattern_type": "user_id",
+ }
+ self._assert_not_matches(
+ condition,
+ {"value": "@user:test"},
+ "should not be possible to pass a pattern_type in",
+ )
+
+ def test_exact_event_match_string(self) -> None:
+ """Check that exact_event_match conditions work as expected for strings."""
+
+ # Test against a string value.
+ condition = {
+ "kind": "event_property_is",
+ "key": "content.value",
+ "value": "foobaz",
+ }
+ self._assert_matches(
+ condition,
+ {"value": "foobaz"},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "FoobaZ"},
+ "values should match and be case-sensitive",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "test foobaz test"},
+ "values must exactly match",
+ )
+ value: Any
+ for value in (True, False, 1, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ # it should work on frozendicts too
+ self._assert_matches(
+ condition,
+ frozendict.frozendict({"value": "foobaz"}),
+ "values should match on frozendicts",
+ )
+
+ def test_exact_event_match_boolean(self) -> None:
+ """Check that exact_event_match conditions work as expected for booleans."""
+
+ # Test against a True boolean value.
+ condition = {"kind": "event_property_is", "key": "content.value", "value": True}
+ self._assert_matches(
+ condition,
+ {"value": True},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": False},
+ "incorrect values should not match",
+ )
+ for value in ("foobaz", 1, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ # Test against a False boolean value.
+ condition = {
+ "kind": "event_property_is",
+ "key": "content.value",
+ "value": False,
+ }
+ self._assert_matches(
+ condition,
+ {"value": False},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": True},
+ "incorrect values should not match",
+ )
+ # Choose false-y values to ensure there's no type coercion.
+ for value in ("", 0, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_match_null(self) -> None:
+ """Check that exact_event_match conditions work as expected for null."""
+
+ condition = {"kind": "event_property_is", "key": "content.value", "value": None}
+ self._assert_matches(
+ condition,
+ {"value": None},
+ "exact value should match",
+ )
+ for value in ("foobaz", True, False, 1, 1.1, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_match_integer(self) -> None:
+ """Check that exact_event_match conditions work as expected for integers."""
+
+ condition = {"kind": "event_property_is", "key": "content.value", "value": 1}
+ self._assert_matches(
+ condition,
+ {"value": 1},
+ "exact value should match",
+ )
+ value: Any
+ for value in (1.1, -1, 0):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect values should not match",
+ )
+ for value in ("1", True, False, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_property_contains(self) -> None:
+ """Check that exact_event_property_contains conditions work as expected."""
+
+ condition = {
+ "kind": "event_property_contains",
+ "key": "content.value",
+ "value": "foobaz",
+ }
+ self._assert_matches(
+ condition,
+ {"value": ["foobaz"]},
+ "exact value should match",
+ )
+ self._assert_matches(
+ condition,
+ {"value": ["foobaz", "bugz"]},
+ "extra values should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": ["FoobaZ"]},
+ "values should match and be case-sensitive",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "foobaz"},
+ "does not search in a string",
+ )
+
+ # it should work on frozendicts too
+ self._assert_matches(
+ condition,
+ freeze({"value": ["foobaz"]}),
+ "values should match on frozendicts",
+ )
+
def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 6a7174b333..46a8e2013e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -16,7 +16,9 @@ from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
-from twisted.internet.protocol import Protocol
+from twisted.internet.protocol import Protocol, connectionDone
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
@@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
@@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
repl_handler,
)
- self._client_transport = None
- self._server_transport = None
+ self._client_transport: Optional[FakeTransport] = None
+ self._server_transport: Optional[FakeTransport] = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
@@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def _build_replication_data_handler(self):
+ def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
return TestReplicationDataHandler(self.worker_hs)
- def reconnect(self):
+ def reconnect(self) -> None:
if self._client_transport:
self.client.close()
@@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)
- def disconnect(self):
+ def disconnect(self) -> None:
if self._client_transport:
self._client_transport = None
self.client.close()
@@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None
self.server.close()
- def replicate(self):
+ def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
@@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory
- def request_factory(*args, **kwargs):
+ def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
request = real_request_factory(*args, **kwargs)
requests.append(request)
return request
@@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
- ):
+ ) -> None:
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
@@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
base["redis"] = {"enabled": True}
return base
- def setUp(self):
+ def setUp(self) -> None:
super().setUp()
# build a replication server
@@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
- def create_test_resource(self):
+ def create_test_resource(self) -> ReplicationRestResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
@@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource
def make_worker_hs(
- self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
+ self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
@@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def replicate(self):
+ def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()
- def _handle_http_replication_attempt(self, hs, repl_port):
+ def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
"""Handles a connection attempt to the given HS replication HTTP
listener on the given port.
"""
@@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
- def connect_any_redis_attempts(self):
+ def connect_any_redis_attempts(self) -> None:
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
@@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
- client_protocol = client_factory.buildProtocol(None)
- server_protocol = self._redis_server.buildProtocol(None)
+ client_address = IPv4Address("TCP", "127.0.0.1", 6379)
+ client_protocol = client_factory.buildProtocol(client_address)
+
+ server_address = IPv4Address("TCP", host, port)
+ server_protocol = self._redis_server.buildProtocol(server_address)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
@@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
# list of received (stream_name, token, row) tuples
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
- async def on_rdata(self, stream_name, instance_name, token, rows):
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ) -> None:
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
@@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
- def __init__(self):
+ def __init__(self) -> None:
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)
- def add_subscriber(self, conn, channel: bytes):
+ def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
self._subscribers_by_channel[channel].add(conn)
- def remove_subscriber(self, conn):
+ def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)
- def publish(self, conn, channel: bytes, msg) -> int:
+ def publish(
+ self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
+ ) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])
return len(self._subscribers_by_channel)
- def buildProtocol(self, addr):
+ def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
return FakeRedisPubSubProtocol(self)
@@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
self._server = server
self._reader = hiredis.Reader()
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes) -> None:
self._reader.feed(data)
# We might get multiple messages in one packet.
@@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.handle_command(msg[0], *msg[1:])
- def handle_command(self, command, *args):
+ def handle_command(self, command: bytes, *args: bytes) -> None:
"""Received a Redis command from the client."""
# We currently only support pub/sub.
@@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
self.send("PONG")
else:
- raise Exception(f"Unknown command: {command}")
+ raise Exception(f"Unknown command: {command!r}")
- def send(self, msg):
+ def send(self, msg: object) -> None:
"""Send a message back to the client."""
assert self.transport is not None
@@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.transport.write(raw)
self.transport.flush()
- def encode(self, obj):
+ def encode(self, obj: object) -> str:
"""Encode an object to its Redis format.
Supports: strings/bytes, integers and list/tuples.
@@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
self._server.remove_subscriber(self)
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index e03d9b4cc0..9be11ab802 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -74,7 +74,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""
- def create_test_resource(self):
+ def create_test_resource(self) -> JsonResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index c5705256e6..4c9b494344 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,35 +13,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Iterable, Optional
from unittest.mock import Mock
-from tests.replication._base import BaseStreamTestCase
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.server import HomeServer
+from synapse.util import Clock
-class BaseSlavedStoreTestCase(BaseStreamTestCase):
- def make_homeserver(self, reactor, clock):
+from tests.replication._base import BaseStreamTestCase
- hs = self.setup_test_homeserver(federation_client=Mock())
- return hs
+class BaseSlavedStoreTestCase(BaseStreamTestCase):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ return self.setup_test_homeserver(federation_client=Mock())
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.reconnect()
self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
- self._storage_controllers = hs.get_storage_controllers()
+ persistence = hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistance = persistence
- def replicate(self):
+ def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump(0.1)
- def check(self, method, args, expected_result=None):
+ def check(
+ self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
+ ) -> None:
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index dce71f7334..57c781a0c3 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Iterable, Optional
+from typing import Any, Callable, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
+from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
+from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import (
NotifCounts,
RoomNotifCounts,
@@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
+from synapse.util import Clock
from tests.server import FakeTransport
@@ -41,35 +46,34 @@ ROOM_ID = "!room:test"
logger = logging.getLogger(__name__)
-def dict_equals(self, other):
+def dict_equals(self: EventBase, other: EventBase) -> bool:
me = encode_canonical_json(self.get_pdu_json())
them = encode_canonical_json(other.get_pdu_json())
return me == them
-def patch__eq__(cls):
+def patch__eq__(cls: object) -> Callable[[], None]:
eq = getattr(cls, "__eq__", None)
- cls.__eq__ = dict_equals
+ cls.__eq__ = dict_equals # type: ignore[assignment]
- def unpatch():
+ def unpatch() -> None:
if eq is not None:
- cls.__eq__ = eq
+ cls.__eq__ = eq # type: ignore[assignment]
return unpatch
class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
-
STORE_TYPE = EventsWorkerStore
- def setUp(self):
+ def setUp(self) -> None:
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEqual
- self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
- return super().setUp()
+ self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
+ super().setUp()
- def prepare(self, *args, **kwargs):
- super().prepare(*args, **kwargs)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
self.get_success(
self.master_store.store_room(
@@ -80,10 +84,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
)
- def tearDown(self):
+ def tearDown(self) -> None:
[unpatch() for unpatch in self.unpatches]
- def test_get_latest_event_ids_in_room(self):
+ def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
@@ -97,7 +101,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
- def test_redactions(self):
+ def test_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
@@ -117,7 +121,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
self.check("get_event", [msg.event_id], redacted)
- def test_backfilled_redactions(self):
+ def test_backfilled_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
@@ -139,7 +143,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
self.check("get_event", [msg.event_id], redacted)
- def test_invites(self):
+ def test_invites(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
@@ -163,7 +167,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
@parameterized.expand([(True,), (False,)])
- def test_push_actions_for_user(self, send_receipt: bool):
+ def test_push_actions_for_user(self, send_receipt: bool) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(
@@ -219,7 +223,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
),
)
- def test_get_rooms_for_user_with_stream_ordering(self):
+ def test_get_rooms_for_user_with_stream_ordering(self) -> None:
"""Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
by rows in the events stream
"""
@@ -243,7 +247,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
{GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
)
- def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
+ self,
+ ) -> None:
"""Check that current_state invalidation happens correctly with multiple events
in the persistence batch.
@@ -283,11 +289,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
- self.get_success(
- self._storage_controllers.persistence.persist_events(
- [(j2, j2ctx), (msg, msgctx)]
- )
- )
+ self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
self.replicate()
assert j2.internal_metadata.stream_ordering is not None
@@ -339,7 +341,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0
- def persist(self, backfill=False, **kwargs) -> FrozenEvent:
+ def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
"""
Returns:
The event that was persisted.
@@ -348,32 +350,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self._storage_controllers.persistence.persist_events(
- [(event, context)], backfilled=True
- )
+ self.persistance.persist_events([(event, context)], backfilled=True)
)
else:
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ self.get_success(self.persistance.persist_event(event, context))
return event
def build_event(
self,
- sender=USER_ID,
- room_id=ROOM_ID,
- type="m.room.message",
- key=None,
+ sender: str = USER_ID,
+ room_id: str = ROOM_ID,
+ type: str = "m.room.message",
+ key: Optional[str] = None,
internal: Optional[dict] = None,
- depth=None,
- prev_events: Optional[list] = None,
- auth_events: Optional[list] = None,
- prev_state: Optional[list] = None,
- redacts=None,
+ depth: Optional[int] = None,
+ prev_events: Optional[List[Tuple[str, dict]]] = None,
+ auth_events: Optional[List[str]] = None,
+ prev_state: Optional[List[str]] = None,
+ redacts: Optional[str] = None,
push_actions: Iterable = frozenset(),
- **content,
- ):
+ **content: object,
+ ) -> Tuple[EventBase, EventContext]:
prev_events = prev_events or []
auth_events = auth_events or []
prev_state = prev_state or []
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 50fbff5f32..01df1be047 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
class AccountDataStreamTestCase(BaseStreamTestCase):
- def test_update_function_room_account_data_limit(self):
+ def test_update_function_room_account_data_limit(self) -> None:
"""Test replication with many room account data updates"""
store = self.hs.get_datastores().main
@@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
- def test_update_function_global_account_data_limit(self):
+ def test_update_function_global_account_data_limit(self) -> None:
"""Test replication with many global account data updates"""
store = self.hs.get_datastores().main
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 641a94133b..65ef4bb160 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional
+from typing import Any, List, Optional, Sequence
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
)
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.replication._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
@@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
@@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
- def test_update_function_event_row_limit(self):
+ def test_update_function_event_row_limit(self) -> None:
"""Test replication with many non-state events
Checks that all events are correctly replicated when there are lots of
@@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
- def test_update_function_huge_state_change(self):
+ def test_update_function_huge_state_change(self) -> None:
"""Test replication with many state events
Ensures that all events are correctly replicated when there are lots of
@@ -135,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point: List[str] = self.get_success(
+ fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -164,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_event = self.get_success(
inject_event(
self.hs,
- prev_event_ids=prev_events,
+ prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
@@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
- def test_update_function_state_row_limit(self):
+ def test_update_function_state_row_limit(self) -> None:
"""Test replication with many state events over several stream ids."""
# we want to generate lots of state changes, but for this test, we want to
@@ -290,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point: List[str] = self.get_success(
+ fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -319,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
e = self.get_success(
inject_event(
self.hs,
- prev_event_ids=prev_events,
+ prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
@@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
- def test_backwards_stream_id(self):
+ def test_backwards_stream_id(self) -> None:
"""
Test that RDATA that comes after the current position should be discarded.
"""
@@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
event_count = 0
def _inject_test_event(
- self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
+ self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
) -> EventBase:
if sender is None:
sender = self.user_id
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index bcb82c9c80..cdbdfaf057 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
config["federation_sender_instances"] = ["federation_sender1"]
return config
- def test_catchup(self):
+ def test_catchup(self) -> None:
"""Basic test of catchup on reconnect
Makes sure that updates sent while we are offline are received later.
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 2c10eab4db..452ac85069 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
hijack_auth = True
user_id = "@bob:test"
- def setUp(self):
+ def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
room_id = self.helper.create_room_as("@bob:test")
# Mark the room as partial-stated.
self.get_success(
- self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
+ self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
)
worker = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 9a229dd23f..5a38ac831f 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -13,7 +13,7 @@
# limitations under the License.
from unittest.mock import Mock
-from synapse.handlers.typing import RoomMember
+from synapse.handlers.typing import RoomMember, TypingWriterHandler
from synapse.replication.tcp.streams import TypingStream
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -27,11 +27,13 @@ ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase):
- def _build_replication_data_handler(self):
- return Mock(wraps=super()._build_replication_data_handler())
+ def _build_replication_data_handler(self) -> Mock:
+ self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
+ return self.mock_handler
- def test_typing(self):
+ def test_typing(self) -> None:
typing = self.hs.get_typing_handler()
+ assert isinstance(typing, TypingWriterHandler)
self.reconnect()
@@ -43,8 +45,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
- self.test_handler.on_rdata.assert_called_once()
- stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.mock_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -54,11 +56,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
# Now let's disconnect and insert some data.
self.disconnect()
- self.test_handler.on_rdata.reset_mock()
+ self.mock_handler.on_rdata.reset_mock()
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
- self.test_handler.on_rdata.assert_not_called()
+ self.mock_handler.on_rdata.assert_not_called()
self.reconnect()
self.pump(0.1)
@@ -71,15 +73,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token)
- self.test_handler.on_rdata.assert_called_once()
- stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.mock_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([], row.user_ids)
- def test_reset(self):
+ def test_reset(self) -> None:
"""
Test what happens when a typing stream resets.
@@ -87,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
sends the proper position and RDATA).
"""
typing = self.hs.get_typing_handler()
+ assert isinstance(typing, TypingWriterHandler)
self.reconnect()
@@ -98,8 +101,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
- self.test_handler.on_rdata.assert_called_once()
- stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.mock_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -134,15 +137,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing")
# Reset the test code.
- self.test_handler.on_rdata.reset_mock()
- self.test_handler.on_rdata.assert_not_called()
+ self.mock_handler.on_rdata.reset_mock()
+ self.mock_handler.on_rdata.assert_not_called()
# Push additional data.
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
self.reactor.advance(0)
- self.test_handler.on_rdata.assert_called_once()
- stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.mock_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index cca7ebb719..5d6b72b16d 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -21,12 +21,12 @@ from tests.unittest import TestCase
class ParseCommandTestCase(TestCase):
- def test_parse_one_word_command(self):
+ def test_parse_one_word_command(self) -> None:
line = "REPLICATE"
cmd = parse_command_from_line(line)
self.assertIsInstance(cmd, ReplicateCommand)
- def test_parse_rdata(self):
+ def test_parse_rdata(self) -> None:
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line)
assert isinstance(cmd, RdataCommand)
@@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
self.assertEqual(cmd.instance_name, "master")
self.assertEqual(cmd.token, 6287863)
- def test_parse_rdata_batch(self):
+ def test_parse_rdata_batch(self) -> None:
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
cmd = parse_command_from_line(line)
assert isinstance(cmd, RdataCommand)
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 6e4055cc21..bab77b2df7 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
# ... updating the cache ID gen on the master still shouldn't cause the
# deferred to wake up.
+ assert store._cache_id_gen is not None
ctx = store._cache_id_gen.get_next()
self.get_success(ctx.__aenter__())
self.get_success(ctx.__aexit__(None, None, None))
@@ -140,3 +141,64 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
self.get_success(ctx_worker1.__aexit__(None, None, None))
self.assertTrue(d.called)
+
+ def test_wait_for_stream_position_rdata(self) -> None:
+ """Check that wait for stream position correctly waits for an update
+ from the correct instance, when RDATA is sent.
+ """
+ store = self.hs.get_datastores().main
+ cmd_handler = self.hs.get_replication_command_handler()
+ data_handler = self.hs.get_replication_data_handler()
+
+ worker1 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+
+ cache_id_gen = worker1.get_datastores().main._cache_id_gen
+ assert cache_id_gen is not None
+
+ self.replicate()
+
+ # First, make sure the master knows that `worker1` exists.
+ initial_token = cache_id_gen.get_current_token()
+ cmd_handler.send_command(
+ PositionCommand("caches", "worker1", initial_token, initial_token)
+ )
+ self.replicate()
+
+ # `wait_for_stream_position` should only return once master receives a
+ # notification that `next_token2` has persisted.
+ ctx_worker1 = cache_id_gen.get_next_mult(2)
+ next_token1, next_token2 = self.get_success(ctx_worker1.__aenter__())
+
+ d = defer.ensureDeferred(
+ data_handler.wait_for_stream_position("worker1", "caches", next_token2)
+ )
+ self.assertFalse(d.called)
+
+ # Insert an entry into the cache stream with token `next_token1`, but
+ # not `next_token2`.
+ self.get_success(
+ store.db_pool.simple_insert(
+ table="cache_invalidation_stream_by_instance",
+ values={
+ "stream_id": next_token1,
+ "instance_name": "worker1",
+ "cache_func": "foo",
+ "keys": [],
+ "invalidation_ts": 0,
+ },
+ )
+ )
+
+ # Finish the context manager, triggering the data to be sent to master.
+ self.get_success(ctx_worker1.__aexit__(None, None, None))
+
+ # Master should get told about `next_token2`, so the deferred should
+ # resolve.
+ self.assertTrue(d.called)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index 545f11acd1..b75fc05fd5 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -16,15 +16,17 @@ from typing import Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IProtocol
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class RemoteServerUpTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.factory = ReplicationStreamProtocolFactory(hs)
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
@@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
return proto, transport
- def test_relay(self):
+ def test_relay(self) -> None:
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all
other connections, but not the one that sent it.
"""
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index 5d7a89e0c7..98602371e4 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -13,7 +13,11 @@
# limitations under the License.
import logging
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.rest.client import register
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, make_request
@@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
servlets = [register.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# This isn't a real configuration option but is used to provide the main
# homeserver and worker homeserver different options.
@@ -77,7 +81,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
{"auth": {"session": session, "type": "m.login.dummy"}},
)
- def test_no_auth(self):
+ def test_no_auth(self) -> None:
"""With no authentication the request should finish."""
channel = self._test_register()
self.assertEqual(channel.code, 200)
@@ -86,7 +90,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.json_body["user_id"], "@user:test")
@override_config({"main_replication_secret": "my-secret"})
- def test_missing_auth(self):
+ def test_missing_auth(self) -> None:
"""If the main process expects a secret that is not provided, an error results."""
channel = self._test_register()
self.assertEqual(channel.code, 500)
@@ -97,13 +101,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
"worker_replication_secret": "wrong-secret",
}
)
- def test_unauthorized(self):
+ def test_unauthorized(self) -> None:
"""If the main process receives the wrong secret, an error results."""
channel = self._test_register()
self.assertEqual(channel.code, 500)
@override_config({"worker_replication_secret": "my-secret"})
- def test_authorized(self):
+ def test_authorized(self) -> None:
"""The request should finish when the worker provides the authentication header."""
channel = self._test_register()
self.assertEqual(channel.code, 200)
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index eb5b376534..eca5033761 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -33,7 +33,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def test_register_single_worker(self):
+ def test_register_single_worker(self) -> None:
"""Test that registration works when using a single generic worker."""
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
site = self._hs_to_site[worker_hs]
@@ -63,7 +63,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
- def test_register_multi_worker(self):
+ def test_register_multi_worker(self) -> None:
"""Test that registration works when using multiple generic workers."""
worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 63b1dd40b5..12668b34c5 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -14,10 +14,14 @@
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -30,12 +34,10 @@ class FederationAckTestCase(HomeserverTestCase):
config["federation_sender_instances"] = ["federation_sender1"]
return config
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
-
- return hs
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
- def test_federation_ack_sent(self):
+ def test_federation_ack_sent(self) -> None:
"""A FEDERATION_ACK should be sent back after each RDATA federation
This test checks that the federation sender is correctly sending back
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index c28073b8f7..08703206a9 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
+from synapse.handlers.typing import TypingWriterHandler
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client import login, room
from synapse.types import UserID, create_requester
@@ -40,7 +41,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
room.register_servlets,
]
- def test_send_event_single_sender(self):
+ def test_send_event_single_sender(self) -> None:
"""Test that using a single federation sender worker correctly sends a
new event.
"""
@@ -71,7 +72,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
- def test_send_event_sharded(self):
+ def test_send_event_sharded(self) -> None:
"""Test that using two federation sender workers correctly sends
new events.
"""
@@ -138,7 +139,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)
- def test_send_typing_sharded(self):
+ def test_send_typing_sharded(self) -> None:
"""Test that using two federation sender workers correctly sends
new typing EDUs.
"""
@@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
token = self.login("user3", "pass")
typing_handler = self.hs.get_typing_handler()
+ assert isinstance(typing_handler, TypingWriterHandler)
sent_on_1 = False
sent_on_2 = False
@@ -215,7 +217,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)
- def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+ def create_room_with_remote_server(
+ self, user: str, token: str, remote_server: str = "other_server"
+ ) -> str:
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastores().main
federation = self.hs.get_federation_event_handler()
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
index b93cae67d3..9c4fbda71b 100644
--- a/tests/replication/test_module_cache_invalidation.py
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -39,7 +39,7 @@ class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
synapse.rest.admin.register_servlets,
]
- def test_module_cache_full_invalidation(self):
+ def test_module_cache_full_invalidation(self) -> None:
main_cache = TestCache()
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 96cdf2c45b..1527b4a82d 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -18,12 +18,14 @@ from typing import Optional, Tuple
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
+from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
+from synapse.util import Clock
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
self.reactor.lookups["example.com"] = "1.2.3.4"
- def default_config(self):
+ def default_config(self) -> dict:
conf = super().default_config()
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf
@@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return channel, request
- def test_basic(self):
+ def test_basic(self) -> None:
"""Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
@@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"Hello!")
- def test_download_simple_file_race(self):
+ def test_download_simple_file_race(self) -> None:
"""Test that fetching remote media from two different processes at the
same time works.
"""
@@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
# We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media())
- def test_download_image_race(self):
+ def test_download_image_race(self) -> None:
"""Test that fetching remote *images* from two different processes at
the same time works.
@@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path))
-def get_connection_factory():
+def get_connection_factory() -> TestServerTLSConnectionFactory:
# this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory
if test_server_connection_factory is None:
@@ -263,6 +265,6 @@ def _build_test_server(
return server_tls_factory.buildProtocol(None)
-def _log_request(request):
+def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index ca18ad6553..0798b021c3 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -15,9 +15,12 @@ import logging
from unittest.mock import Mock
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
- def _create_pusher_and_send_msg(self, localpart):
+ def _create_pusher_and_send_msg(self, localpart: str) -> str:
# Create a user that will get push notifications
user_id = self.register_user(localpart, "pass")
access_token = self.login(localpart, "pass")
@@ -47,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_dict is not None
token_id = user_dict.token_id
self.get_success(
@@ -79,7 +83,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
return event_id
- def test_send_push_single_worker(self):
+ def test_send_push_single_worker(self) -> None:
"""Test that registration works when using a pusher worker."""
http_client_mock = Mock(spec_set=["post_json_get_json"])
http_client_mock.post_json_get_json.side_effect = (
@@ -109,7 +113,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
],
)
- def test_send_push_multiple_workers(self):
+ def test_send_push_multiple_workers(self) -> None:
"""Test that registration works when using sharded pusher workers."""
http_client_mock1 = Mock(spec_set=["post_json_get_json"])
http_client_mock1.post_json_get_json.side_effect = (
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 541d390286..7f9cc67e73 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -14,9 +14,13 @@
import logging
from unittest.mock import patch
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.rest import admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
@@ -34,7 +38,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
@@ -42,7 +46,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.room_creator = self.hs.get_room_creation_handler()
self.store = hs.get_datastores().main
- def default_config(self):
+ def default_config(self) -> dict:
conf = super().default_config()
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
@@ -51,7 +55,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
- def _create_room(self, room_id: str, user_id: str, tok: str):
+ def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
"""Create a room with given room_id"""
# We control the room ID generation by patching out the
@@ -62,7 +66,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
mock.side_effect = lambda: room_id
self.helper.create_room_as(user_id, tok=tok)
- def test_basic(self):
+ def test_basic(self) -> None:
"""Simple test to ensure that multiple rooms can be created and joined,
and that different rooms get handled by different instances.
"""
@@ -112,7 +116,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(persisted_on_1)
self.assertTrue(persisted_on_2)
- def test_vector_clock_token(self):
+ def test_vector_clock_token(self) -> None:
"""Tests that using a stream token with a vector clock component works
correctly with basic /sync and /messages usage.
"""
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 03f2112b07..aaa488bced 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -28,7 +28,6 @@ from tests import unittest
class DeviceRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -291,7 +290,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
class DevicesRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -415,7 +413,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 233eba3516..f189b07769 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -78,7 +78,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
Try to get an event report without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, {})
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -473,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
Try to get event report without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, {})
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -599,3 +599,142 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.assertIn("room_id", content["event_json"])
self.assertIn("sender", content["event_json"])
self.assertIn("content", content["event_json"])
+
+
+class DeleteEventReportTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self._store = hs.get_datastores().main
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # create report
+ event_id = self.get_success(
+ self._store.add_event_report(
+ "room_id",
+ "event_id",
+ self.other_user,
+ "this makes me sad",
+ {},
+ self.clock.time_msec(),
+ )
+ )
+
+ self.url = f"/_synapse/admin/v1/event_reports/{event_id}"
+
+ def test_no_auth(self) -> None:
+ """
+ Try to delete event report without authentication.
+ """
+ channel = self.make_request("DELETE", self.url)
+
+ self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self) -> None:
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_delete_success(self) -> None:
+ """
+ Testing delete a report.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ # check that report was deleted
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_invalid_report_id(self) -> None:
+ """
+ Testing that an invalid `report_id` returns a 400.
+ """
+
+ # `report_id` is negative
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/-123",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is a non-numerical string
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/abcdef",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is undefined
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ def test_report_id_not_found(self) -> None:
+ """
+ Testing that a not existing `report_id` returns a 404.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/123",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ self.assertEqual("Event report not found", channel.json_body["error"])
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index aadb31ca83..6d04911d67 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
+from synapse.media.filepath import MediaFilePaths
from synapse.rest.client import login, profile, room
-from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
from synapse.util import Clock
@@ -34,7 +34,6 @@ INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -196,7 +195,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -213,7 +211,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.admin_user_tok = self.login("admin", "pass")
self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
- self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
+ self.url = "/_synapse/admin/v1/media/delete"
+ self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
# Move clock up to somewhat realistic time
self.reactor.advance(1000000000)
@@ -332,11 +331,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- def test_delete_media_never_accessed(self) -> None:
+ @parameterized.expand([(True,), (False,)])
+ def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None:
"""
Tests that media deleted if it is older than `before_ts` and never accessed
`last_access_ts` is `NULL` and `created_ts` < `before_ts`
"""
+ url = self.legacy_url if use_legacy_url else self.url
# upload and do not access
server_and_media_id = self._create_media()
@@ -351,7 +352,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
now_ms = self.clock.time_msec()
channel = self.make_request(
"POST",
- self.url + "?before_ts=" + str(now_ms),
+ url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -591,7 +592,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -721,7 +721,6 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -818,7 +817,6 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 453a6e979c..9dbb778679 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1990,7 +1990,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
room.register_servlets,
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index a2f347f666..28b999573e 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Sequence
from twisted.test.proto_helpers import MemoryReactor
@@ -28,7 +28,6 @@ from tests.unittest import override_config
class ServerNoticeTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -558,7 +557,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int
- ) -> List[RoomsForUser]:
+ ) -> Sequence[RoomsForUser]:
"""Check invite and room membership status of a user.
Args
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..4b8f889a71 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,8 +28,8 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
+from synapse.media.filepath import MediaFilePaths
from synapse.rest.client import devices, login, logout, profile, register, room, sync
-from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
- storage_controllers = self.hs.get_storage_controllers()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@@ -2934,11 +2935,13 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
- self.get_success(storage_controllers.persistence.persist_event(event, context))
+ context = self.get_success(unpersisted_context.persist(event))
+
+ self.get_success(persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 30f12f1bff..6c04e6c56c 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- async def check_username(username: str) -> bool:
- if username == "allowed":
- return True
+ async def check_username(
+ localpart: str,
+ guest_access_token: Optional[str] = None,
+ assigned_user_id: Optional[str] = None,
+ inhibit_user_in_use_error: bool = False,
+ ) -> None:
+ if localpart == "allowed":
+ return
raise SynapseError(
400,
"User ID already taken.",
@@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
)
handler = self.hs.get_registration_handler()
- handler.check_username = check_username
+ handler.check_username = check_username # type: ignore[assignment]
def test_username_available(self) -> None:
"""
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 88f255c9ee..2b05dffc7d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
class PasswordResetTestCase(unittest.HomeserverTestCase):
-
servlets = [
account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
class DeactivateTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
class WhoamiTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
account.register_servlets,
login.register_servlets,
@@ -1193,7 +1189,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
return {}
# Register a mock that will return the expected result depending on the remote.
- self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
+ self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment]
# Check that we've got the correct response from the client-side endpoint.
self._test_status(
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..0d8fe77b88 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless
@@ -43,13 +43,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
super().__init__(hs)
self.recaptcha_attempts: List[Tuple[dict, str]] = []
+ def is_enabled(self) -> bool:
+ return True
+
def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
class FallbackAuthTests(unittest.HomeserverTestCase):
-
servlets = [
auth.register_servlets,
register.register_servlets,
@@ -57,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
hijack_auth = False
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_registration_captcha"] = True
@@ -1319,16 +1320,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
- # Now try to exchange the login token
- channel = make_request(
- self.hs.get_reactor(),
- self.site,
- "POST",
- "/login",
- content={"type": "m.login.token", "token": login_token},
- )
- # It should have failed
- self.assertEqual(channel.code, 403)
+ # Now try to exchange the login token, it should fail.
+ self.helper.login_via_token(login_token, 403)
@override_config(
{
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index d1751e1557..c16e8d43f4 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -26,7 +26,6 @@ from tests.unittest import override_config
class CapabilitiesTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
capabilities.register_servlets,
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index b1ca81a911..bb845179d3 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
hijack_auth = False
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["form_secret"] = "123abc"
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index 7a88aa2cda..6490e883bf 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -28,7 +28,6 @@ from tests.unittest import override_config
class DirectoryTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 9fa1f82dfe..f31ebc8021 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -26,7 +26,6 @@ from tests import unittest
class EphemeralMessageTestCase(unittest.HomeserverTestCase):
-
user_id = "@user:test"
servlets = [
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index a9b7db9db2..54df2a252c 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_registration_captcha"] = False
config["enable_registration"] = True
@@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")
@@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index afc8d641be..91678abf13 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.HomeserverTestCase):
-
user_id = "@apple:test"
hijack_auth = True
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
@@ -63,14 +62,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine
- self.hs.is_mine = lambda target_user: False
+ self.hs.is_mine = lambda target_user: False # type: ignore[assignment]
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- self.hs.is_mine = _is_mine
+ self.hs.is_mine = _is_mine # type: ignore[assignment]
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 741fecea77..8ee5489057 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -14,12 +14,21 @@
from http import HTTPStatus
+from signedjson.key import (
+ encode_verify_key_base64,
+ generate_signing_key,
+ get_verify_key,
+)
+from signedjson.sign import sign_json
+
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import keys, login
+from synapse.types import JsonDict
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
+from tests.unittest import override_config
class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -118,3 +127,135 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertIn(bob, channel.json_body["device_keys"])
+
+ def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
+ # We only generate a master key to simplify the test.
+ master_signing_key = generate_signing_key(device_id)
+ master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key))
+
+ return {
+ "master_key": sign_json(
+ {
+ "user_id": user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_verify_key: master_verify_key},
+ },
+ user_id,
+ master_signing_key,
+ ),
+ }
+
+ def test_device_signing_with_uia(self) -> None:
+ """Device signing key upload requires UIA."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ content = self.make_device_keys(alice_id, device_id)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # add UI auth
+ content["auth"] = {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": alice_id},
+ "password": password,
+ "session": session,
+ }
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ @override_config({"ui_auth": {"session_timeout": "15m"}})
+ def test_device_signing_with_uia_session_timeout(self) -> None:
+ """Device signing key upload requires UIA buy passes with grace period."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ content = self.make_device_keys(alice_id, device_id)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ @override_config(
+ {
+ "experimental_features": {"msc3967_enabled": True},
+ "ui_auth": {"session_timeout": "15s"},
+ }
+ )
+ def test_device_signing_with_msc3967(self) -> None:
+ """Device signing key follows MSC3967 behaviour when enabled."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ keys1 = self.make_device_keys(alice_id, device_id)
+
+ # Initial request should succeed as no existing keys are present.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys1,
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ keys2 = self.make_device_keys(alice_id, device_id)
+
+ # Subsequent request should require UIA as keys already exist even though session_timeout is set.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys2,
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # add UI auth
+ keys2["auth"] = {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": alice_id},
+ "password": password,
+ "session": session,
+ }
+
+ # Request should complete
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys2,
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index ff5baa9f0a..62acf4f44e 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [
class LoginRestServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
class CASTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
]
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index 6aedc1a11c..b8187db982 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
admin.register_servlets,
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index b3738a0304..dcbb125a3b 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -35,15 +35,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
servlets = [presence.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
- presence_handler = Mock(spec=PresenceHandler)
- presence_handler.set_state.return_value = make_awaitable(None)
+ self.presence_handler = Mock(spec=PresenceHandler)
+ self.presence_handler.set_state.return_value = make_awaitable(None)
hs = self.setup_test_homeserver(
"red",
federation_http_client=None,
federation_client=Mock(),
- presence_handler=presence_handler,
+ presence_handler=self.presence_handler,
)
return hs
@@ -61,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
+ self.assertEqual(self.presence_handler.set_state.call_count, 1)
@unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self) -> None:
@@ -76,4 +75,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
+ self.assertEqual(self.presence_handler.set_state.call_count, 0)
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 8de5a342ae..27c93ad761 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -30,7 +30,6 @@ from tests import unittest
class ProfileTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase):
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 11cf3939d8..b228dba861 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
register.register_servlets,
@@ -151,7 +150,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self) -> None:
- self.hs.config.key.macaroon_secret_key = "test"
+ self.hs.config.key.macaroon_secret_key = b"test"
self.hs.config.registration.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
class AccountValidityTestCase(unittest.HomeserverTestCase):
-
servlets = [
register.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
-
servlets = [
register.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
-
servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -1166,12 +1162,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
"""
user_id = self.register_user("kermit_delta", "user")
- self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+ self.hs.config.account_validity.account_validity_startup_job_max_delta = (
+ self.max_delta
+ )
now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+ assert res is not None
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c8a6911d5e..fbbbcb23f1 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -30,7 +30,6 @@ from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
-from tests.unittest import override_config
class BaseRelationsTestCase(unittest.HomeserverTestCase):
@@ -403,7 +402,7 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_edit(self) -> None:
"""Test that a simple edit works."""
-
+ orig_body = {"body": "Hi!", "msgtype": "m.text"}
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
edit_event_content = {
"msgtype": "m.text",
@@ -424,9 +423,7 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
- )
+ self.assertEqual(channel.json_body["content"], orig_body)
self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
# Request the room messages.
@@ -443,7 +440,7 @@ class RelationsTestCase(BaseRelationsTestCase):
)
# Request the room context.
- # /context should return the edited event.
+ # /context should return the event.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
@@ -453,7 +450,7 @@ class RelationsTestCase(BaseRelationsTestCase):
self._assert_edit_bundle(
channel.json_body["event"], edit_event_id, edit_event_content
)
- self.assertEqual(channel.json_body["event"]["content"], new_body)
+ self.assertEqual(channel.json_body["event"]["content"], orig_body)
# Request sync, but limit the timeline so it becomes limited (and includes
# bundled aggregations).
@@ -491,45 +488,11 @@ class RelationsTestCase(BaseRelationsTestCase):
edit_event_content,
)
- @override_config({"experimental_features": {"msc3925_inhibit_edit": True}})
- def test_edit_inhibit_replace(self) -> None:
- """
- If msc3925_inhibit_edit is enabled, then the original event should not be
- replaced.
- """
-
- new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- edit_event_content = {
- "msgtype": "m.text",
- "body": "foo",
- "m.new_content": new_body,
- }
- channel = self._send_relation(
- RelationTypes.REPLACE,
- "m.room.message",
- content=edit_event_content,
- )
- edit_event_id = channel.json_body["event_id"]
-
- # /context should return the *original* event.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"}
- )
- self._assert_edit_bundle(
- channel.json_body["event"], edit_event_id, edit_event_content
- )
-
def test_multi_edit(self) -> None:
"""Test that multiple edits, including attempts by people who
shouldn't be allowed, are correctly handled.
"""
-
+ orig_body = orig_body = {"body": "Hi!", "msgtype": "m.text"}
self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
@@ -570,7 +533,7 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["event"]["content"], new_body)
+ self.assertEqual(channel.json_body["event"]["content"], orig_body)
self._assert_edit_bundle(
channel.json_body["event"], edit_event_id, edit_event_content
)
@@ -642,6 +605,7 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_edit_edit(self) -> None:
"""Test that an edit cannot be edited."""
+ orig_body = {"body": "Hi!", "msgtype": "m.text"}
new_body = {"msgtype": "m.text", "body": "Initial edit"}
edit_event_content = {
"msgtype": "m.text",
@@ -675,14 +639,12 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
- )
+ self.assertEqual(channel.json_body["content"], orig_body)
# The relations information should not include the edit to the edit.
self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
- # /context should return the event updated for the *first* edit
+ # /context should return the bundled edit for the *first* edit
# (The edit to the edit should be ignored.)
channel = self.make_request(
"GET",
@@ -690,7 +652,7 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["event"]["content"], new_body)
+ self.assertEqual(channel.json_body["event"]["content"], orig_body)
self._assert_edit_bundle(
channel.json_body["event"], edit_event_id, edit_event_content
)
@@ -1080,48 +1042,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
]
assert_bundle(self._find_event_in_chunk(chunk))
- def test_annotation(self) -> None:
- """
- Test that annotations get correctly bundled.
- """
- # Setup by sending a variety of relations.
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
- def assert_annotations(bundled_aggregations: JsonDict) -> None:
- self.assertEqual(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- bundled_aggregations,
- )
-
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
-
- def test_annotation_to_annotation(self) -> None:
- """Any relation to an annotation should be ignored."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- event_id = channel.json_body["event_id"]
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id
- )
-
- # Fetch the initial annotation event to see if it has bundled aggregations.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- # The first annotationt should not have any bundled aggregations.
- self.assertNotIn("m.relations", channel.json_body["unsigned"])
-
def test_reference(self) -> None:
"""
Test that references get correctly bundled.
@@ -1138,7 +1058,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
def test_thread(self) -> None:
"""
@@ -1183,7 +1103,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
- self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7)
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 6)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
@@ -1208,9 +1128,10 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_2 = channel.json_body["event_id"]
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_2
)
+ reference_event_id = channel.json_body["event_id"]
def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
@@ -1235,17 +1156,15 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assert_dict(
{
"m.relations": {
- RelationTypes.ANNOTATION: {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 1},
- ]
+ RelationTypes.REFERENCE: {
+ "chunk": [{"event_id": reference_event_id}]
},
}
},
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 6)
def test_nested_thread(self) -> None:
"""
@@ -1330,7 +1249,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
thread_summary = relations_dict[RelationTypes.THREAD]
self.assertIn("latest_event", thread_summary)
latest_event_in_thread = thread_summary["latest_event"]
- self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
# The latest event in the thread should have the edit appear under the
# bundled aggregations.
self.assertDictContainsSubset(
@@ -1363,10 +1281,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_id = channel.json_body["event_id"]
- # Annotate the thread.
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+ # Make a reference to the thread.
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_id
)
+ reference_event_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
@@ -1377,9 +1296,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(
channel.json_body["unsigned"].get("m.relations"),
{
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
+ RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
},
)
@@ -1396,9 +1313,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(
thread_message["unsigned"].get("m.relations"),
{
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
+ RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
},
)
@@ -1410,7 +1325,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
Note that the spec allows for a server to return additional fields beyond
what is specified.
"""
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
+ reference_event_id = channel.json_body["event_id"]
# Note that the sync filter does not include "unsigned" as a field.
filter = urllib.parse.quote_plus(
@@ -1428,7 +1344,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# Ensure there's bundled aggregations on it.
self.assertIn("unsigned", parent_event)
- self.assertIn("m.relations", parent_event["unsigned"])
+ self.assertEqual(
+ parent_event["unsigned"].get("m.relations"),
+ {
+ RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
+ },
+ )
class RelationIgnoredUserTestCase(BaseRelationsTestCase):
@@ -1475,53 +1396,8 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
return before_aggregations[relation_type], after_aggregations[relation_type]
- def test_annotation(self) -> None:
- """Annotations should ignore"""
- # Send 2 from us, 2 from the to be ignored user.
- allowed_event_ids = []
- ignored_event_ids = []
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- allowed_event_ids.append(channel.json_body["event_id"])
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
- allowed_event_ids.append(channel.json_body["event_id"])
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="a",
- access_token=self.user2_token,
- )
- ignored_event_ids.append(channel.json_body["event_id"])
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="c",
- access_token=self.user2_token,
- )
- ignored_event_ids.append(channel.json_body["event_id"])
-
- before_aggregations, after_aggregations = self._test_ignored_user(
- RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids
- )
-
- self.assertCountEqual(
- before_aggregations["chunk"],
- [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- {"type": "m.reaction", "key": "c", "count": 1},
- ],
- )
-
- self.assertCountEqual(
- after_aggregations["chunk"],
- [
- {"type": "m.reaction", "key": "a", "count": 1},
- {"type": "m.reaction", "key": "b", "count": 1},
- ],
- )
-
def test_reference(self) -> None:
- """Annotations should ignore"""
+ """Aggregations should exclude reference relations from ignored users"""
channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
allowed_event_ids = [channel.json_body["event_id"]]
@@ -1544,7 +1420,7 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
)
def test_thread(self) -> None:
- """Annotations should ignore"""
+ """Aggregations should exclude thread releations from ignored users"""
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
allowed_event_ids = [channel.json_body["event_id"]]
@@ -1618,43 +1494,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
for t in threads
]
- def test_redact_relation_annotation(self) -> None:
- """
- Test that annotations of an event are properly handled after the
- annotation is redacted.
-
- The redacted relation should not be included in bundled aggregations or
- the response to relations.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- to_redact_event_id = channel.json_body["event_id"]
-
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- unredacted_event_id = channel.json_body["event_id"]
-
- # Both relations should exist.
- event_ids = self._get_related_events()
- relations = self._get_bundled_aggregations()
- self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
- self.assertEquals(
- relations["m.annotation"],
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
- )
-
- # Redact one of the reactions.
- self._redact(to_redact_event_id)
-
- # The unredacted relation should still exist.
- event_ids = self._get_related_events()
- relations = self._get_bundled_aggregations()
- self.assertEquals(event_ids, [unredacted_event_id])
- self.assertEquals(
- relations["m.annotation"],
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
- )
-
def test_redact_relation_thread(self) -> None:
"""
Test that thread replies are properly handled after the thread reply redacted.
@@ -1775,14 +1614,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
is redacted.
"""
# Add a relation
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+ channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
related_event_id = channel.json_body["event_id"]
# The relations should exist.
event_ids = self._get_related_events()
relations = self._get_bundled_aggregations()
self.assertEqual(len(event_ids), 1)
- self.assertIn(RelationTypes.ANNOTATION, relations)
+ self.assertIn(RelationTypes.REFERENCE, relations)
# Redact the original event.
self._redact(self.parent_id)
@@ -1792,8 +1631,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [related_event_id])
self.assertEquals(
- relations["m.annotation"],
- {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
+ relations[RelationTypes.REFERENCE],
+ {"chunk": [{"event_id": related_event_id}]},
)
def test_redact_parent_thread(self) -> None:
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index c0eb5d01a6..8dbd64be55 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
class RendezvousServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
rendezvous.register_servlets,
]
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 7cb1017a4a..1250685d39 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -73,6 +73,18 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
data = {"reason": None, "score": None}
self._assert_status(400, data)
+ def test_cannot_report_nonexistent_event(self) -> None:
+ """
+ Tests that we don't accept event reports for events which do not exist.
+ """
+ channel = self.make_request(
+ "POST",
+ f"rooms/{self.room_id}/report/$nonsenseeventid:test",
+ {"reason": "i am very sad"},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(404, channel.code, msg=channel.result["body"])
+
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
"POST", self.report_path, data, access_token=self.other_user_tok
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 9c8c1889d3..d3e06bf6b3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
first_event_id = resp.get("event_id")
+ assert isinstance(first_event_id, str)
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id")
+ assert isinstance(valid_event_id, str)
# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
@@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Check that we can still access state events that were sent before the event that
# has been purged.
- self.get_event(room_id, create_event.event_id)
+ self.get_event(room_id, bool(create_event))
def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
event = self.get_success(self.store.get_event(event_id, allow_none=True))
@@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.assertIsNone(event)
return {}
- self.assertIsNotNone(event)
+ assert event is not None
time_now = self.clock.time_msec()
serialized = self.serializer.serialize_event(event, time_now)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9222cab198..a4900703c4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase):
servlets = [room.register_servlets, room.register_deprecated_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.hs = self.setup_test_homeserver(
"red",
federation_http_client=None,
@@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase):
rmcreator_id = "@notme:red"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id
self.uncreated_rmid = "!aa:test"
@@ -715,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(33, channel.resource_usage.db_txn_count)
+ self.assertEqual(30, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -728,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(36, channel.resource_usage.db_txn_count)
+ self.assertEqual(32, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase):
class RoomJoinTestCase(RoomBase):
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
hijack_auth = False
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
# Register the user who does the searching
self.user_id2 = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
@@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.url = b"/_matrix/client/r0/publicRooms"
config = self.default_config()
@@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["allow_public_rooms_without_auth"] = True
self.hs = self.setup_test_homeserver(config=config)
@@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase):
class ContextTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
class ThreepidInviteTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -3382,8 +3371,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
- self.hs.get_identity_handler().lookup_3pid = Mock(
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
+ self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
)
@@ -3438,13 +3427,14 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
"""
Test allowing/blocking threepid invites with a spam-check module.
- In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+ """
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
- self.hs.get_identity_handler().lookup_3pid = Mock(
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
+ self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
)
@@ -3563,8 +3553,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
)
event.internal_metadata.outlier = True
+ persistence = self._storage_controllers.persistence
+ assert persistence is not None
self.get_success(
- self._storage_controllers.persistence.persist_event(
+ persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers)
)
)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index c807a37bc2..8d2cdf8751 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase):
def test_invite_3pid(self) -> None:
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_identity_handler()
- identity_handler.lookup_3pid = Mock(
+ identity_handler.lookup_3pid = Mock( # type: ignore[assignment]
side_effect=AssertionError("This should not get called")
)
@@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase):
event_source.get_new_events(
user=UserID.from_string(self.other_user_id),
from_key=0,
- limit=None,
+ limit=10,
room_ids=[room_id],
is_guest=False,
)
@@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id,
)
)
+ assert event is not None
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
@@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id,
)
)
+ assert event is not None
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index b9047194dd..9c876c7a32 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -41,7 +41,6 @@ from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
-
user_id = "@apple:test"
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
class SyncTypingTests(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
class ExcludeRoomTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3325d43a2f..3b99513707 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
+
# patch the rules module with a Mock which will return False for some event
# types
async def check(
@@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_modify_event(self) -> None:
"""The module can return a modified version of the event"""
+
# first patch the event checker so that it will modify the event
async def check(
ev: EventBase, state: StateMap[EventBase]
@@ -275,6 +277,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""
+
# first patch the event checker so that it will modify the event
async def check(
ev: EventBase, state: StateMap[EventBase]
@@ -425,7 +428,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
- if event.is_state and event.type == EventTypes.PowerLevels:
+ if event.is_state() and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
"room_id": event.room_id,
@@ -931,3 +934,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right parameters
self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_add_and_remove_user_third_party_identifier(self) -> None:
+ """Tests that the on_add_user_third_party_identifier and
+ on_remove_user_third_party_identifier module callbacks are called
+ just before associating and removing a 3PID to/from an account.
+ """
+ # Pretend to be a Synapse module and register both callbacks as mocks.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_add_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_add_user_third_party_identifier_callback_mock
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked add callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_add_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_add_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ # Now remove the 3PID from the user
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_remove_user_third_party_identifier_is_called_on_deactivate(
+ self,
+ ) -> None:
+ """Tests that the on_remove_user_third_party_identifier module callback is called
+ when a user is deactivated and their third-party ID associations are deleted.
+ """
+ # Pretend to be a Synapse module and register both callbacks as mocks.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Now deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "deactivated": True,
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3086e1b565..d8dc56261a 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
- self.mock_key = "foo"
+
+ # Here we make sure that we're setting all the fields that HttpTransactionCache
+ # uses to build the transaction key.
+ self.mock_request = Mock()
+ self.mock_request.path = b"/foo/bar"
+ self.mock_requester = Mock()
+ self.mock_requester.app_service = None
+ self.mock_requester.is_guest = False
+ self.mock_requester.access_token_id = 1234
@defer.inlineCallbacks
def test_executes_given_function(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
- res = yield self.cache.fetch_or_execute(
- self.mock_key, cb, "some_arg", keyword="arg"
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
)
cb.assert_called_once_with("some_arg", keyword="arg")
self.assertEqual(res, self.mock_http_response)
@@ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
- res = yield self.cache.fetch_or_execute(
- self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request,
+ self.mock_requester,
+ cb,
+ "some_arg",
+ keyword="arg",
+ changing_args=i,
)
self.assertEqual(res, self.mock_http_response)
# expect only a single call to do the work
@@ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test() -> Generator["defer.Deferred[Any]", object, None]:
with LoggingContext("c") as c1:
- res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
self.assertIs(current_context(), c1)
self.assertEqual(res, (1, {}))
@@ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context:
try:
- yield self.cache.fetch_or_execute(self.mock_key, cb)
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context)
- res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context)
@@ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context:
try:
- yield self.cache.fetch_or_execute(self.mock_key, cb)
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context)
- res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
- yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "an arg"
+ )
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
- yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "an arg"
+ )
# still using cache
cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
- yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "an arg"
+ )
# no longer using cache
self.assertEqual(cb.call_count, 2)
self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5ec343dd7f..0b4c691318 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.room_id, EventTypes.Tombstone, ""
)
)
- self.assertIsNotNone(tombstone_event)
+ assert tombstone_event is not None
self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
# Check that the new room exists.
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
import attr
from typing_extensions import Literal
+from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -67,6 +68,7 @@ class RestHelper:
"""
hs: HomeServer
+ reactor: MemoryReactorClock
site: Site
auth_user_id: Optional[str]
@@ -142,7 +144,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -216,7 +218,7 @@ class RestHelper:
data["reason"] = reason
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -313,7 +315,7 @@ class RestHelper:
data.update(extra_data or {})
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -394,7 +396,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -433,7 +435,7 @@ class RestHelper:
path = path + f"?access_token={tok}"
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
path,
@@ -488,7 +490,7 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+ channel = make_request(self.reactor, self.site, method, path, content)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
- self.hs.get_reactor(),
- FakeSite(resource, self.hs.get_reactor()),
+ self.reactor,
+ FakeSite(resource, self.reactor),
"POST",
path,
content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
expect_code: The return code to expect from attempting the whoami request
"""
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
"account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
- Returns the result of the final token login.
+ Returns the result of the final token login and the fake authorization grant.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
assert m, channel.text_body
login_token = m.group(1)
- # finally, submit the matrix login token to the login API, which gives us our
- # matrix access token and device id.
+ return self.login_via_token(login_token, expected_status), grant
+
+ def login_via_token(
+ self,
+ login_token: str,
+ expected_status: int = 200,
+ ) -> JsonDict:
+ """Submit the matrix login token to the login API, which gives us our
+ matrix access token and device id.Log in (as a new user) via OIDC
+
+ Returns the result of the token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
"/login",
@@ -684,7 +704,7 @@ class RestHelper:
assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
- return channel.json_body, grant
+ return channel.json_body
def auth_via_oidc(
self,
@@ -805,7 +825,7 @@ class RestHelper:
with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
uri,
@@ -867,7 +887,7 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
+ urllib.parse.urlencode({"session": ui_auth_session_id})
)
# hit the redirect url (which will issue a cookie and state)
- channel = make_request(
- self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
- )
+ channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
# that should serve a confirmation page
assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies)
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index 23f227aed6..b59d9dfd4d 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -31,7 +31,6 @@ from tests.utils import MockClock
class MediaRetentionTestCase(unittest.HomeserverTestCase):
-
ONE_DAY_IN_MS = 24 * 60 * 60 * 1000
THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 2c321f8d04..e91dc581c2 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -26,8 +26,8 @@ from twisted.internet.interfaces import IAddress, IResolutionReceiver
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from synapse.config.oembed import OEmbedEndpointConfig
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
-from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.rest.media.media_repository_resource import MediaRepositoryResource
+from synapse.rest.media.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -58,7 +58,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["url_preview_enabled"] = True
config["max_spider_size"] = 9999999
@@ -83,7 +82,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
config["media_store_path"] = self.media_store_path
provider_config = {
- "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
@@ -118,7 +117,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
self.media_repo = hs.get_media_repository_resource()
self.preview_url = self.media_repo.children[b"preview_url"]
@@ -133,7 +131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
transportSemantics: str = "TCP",
) -> IResolutionReceiver:
-
resolution = HostResolution(hostName)
resolutionReceiver.resolutionBegan(resolution)
if hostName not in self.lookups:
@@ -660,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""If the preview image doesn't exist, ensure some data is returned."""
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
- end_content = (
+ result = (
b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
)
@@ -681,8 +678,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
- % (len(end_content),)
- + end_content
+ % (len(result),)
+ + result
)
self.pump()
@@ -691,6 +688,44 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# The image should not be in the result.
self.assertNotIn("og:image", channel.json_body)
+ def test_oembed_failure(self) -> None:
+ """If the autodiscovered oEmbed URL fails, ensure some data is returned."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ result = b"""
+ <title>oEmbed Autodiscovery Fail</title>
+ <link rel="alternate" type="application/json+oembed"
+ href="http://example.com/oembed?url=http%3A%2F%2Fmatrix.org&format=json"
+ title="matrixdotorg" />
+ """
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # The image should not be in the result.
+ self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail")
+
def test_data_url(self) -> None:
"""
Requesting to preview a data URL is not supported.
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 22f99c6ab1..3285f2433c 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,29 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Optional
from unittest.mock import Mock, patch
from synapse._scripts.register_new_matrix_user import request_registration
+from synapse.types import JsonDict
from tests.unittest import TestCase
class RegisterTestCase(TestCase):
- def test_success(self):
+ def test_success(self) -> None:
"""
The script will fetch a nonce, and then generate a MAC with it, and then
post that MAC.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 200
r.json = lambda: {"nonce": "a"}
return r
- def post(url, json=None, verify=None):
+ def post(
+ url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+ ) -> Mock:
# Make sure we are sent the correct info
+ assert json is not None
self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a")
@@ -70,12 +74,12 @@ class RegisterTestCase(TestCase):
# sys.exit shouldn't have been called.
self.assertEqual(err_code, [])
- def test_failure_nonce(self):
+ def test_failure_nonce(self) -> None:
"""
If the script fails to fetch a nonce, it throws an error and quits.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 404
r.reason = "Not Found"
@@ -107,20 +111,23 @@ class RegisterTestCase(TestCase):
self.assertIn("ERROR! Received 404 Not Found", out)
self.assertNotIn("Success!", out)
- def test_failure_post(self):
+ def test_failure_post(self) -> None:
"""
The script will fetch a nonce, and then if the final POST fails, will
report an error and quit.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 200
r.json = lambda: {"nonce": "a"}
return r
- def post(url, json=None, verify=None):
+ def post(
+ url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+ ) -> Mock:
# Make sure we are sent the correct info
+ assert json is not None
self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a")
diff --git a/tests/server.py b/tests/server.py
index b1730fcc8d..5de9722766 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,20 +22,25 @@ import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
+ Any,
+ Awaitable,
Callable,
Dict,
Iterable,
List,
MutableMapping,
Optional,
+ Sequence,
Tuple,
Type,
+ TypeVar,
Union,
+ cast,
)
from unittest.mock import Mock
import attr
-from typing_extensions import Deque
+from typing_extensions import Deque, ParamSpec
from zope.interface import implementer
from twisted.internet import address, threads, udp
@@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
+ IConnector,
IConsumer,
IHostnameResolver,
+ IProducer,
IProtocol,
IPullProducer,
IPushProducer,
@@ -54,6 +61,8 @@ from twisted.internet.interfaces import (
IResolverSimple,
ITransport,
)
+from twisted.internet.protocol import ClientFactory, DatagramProtocol
+from twisted.python import threadpool
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
@@ -61,6 +70,7 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
+from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -70,7 +80,7 @@ from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
-from synapse.types import JsonDict
+from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
from tests.utils import (
@@ -88,6 +98,9 @@ from tests.utils import (
logger = logging.getLogger(__name__)
+R = TypeVar("R")
+P = ParamSpec("P")
+
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
@@ -98,12 +111,14 @@ class TimedOutException(Exception):
"""
-@implementer(IConsumer)
+@implementer(ITransport, IPushProducer, IConsumer)
@attr.s(auto_attribs=True)
class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
+
+ See twisted.web.http.HTTPChannel.
"""
site: Union[Site, "FakeSite"]
@@ -142,7 +157,7 @@ class FakeChannel:
Raises an exception if the request has not yet completed.
"""
- if not self.is_finished:
+ if not self.is_finished():
raise Exception("Request not yet completed")
return self.result["body"].decode("utf8")
@@ -165,27 +180,36 @@ class FakeChannel:
h.addRawHeader(*i)
return h
- def writeHeaders(self, version, code, reason, headers):
+ def writeHeaders(
+ self, version: bytes, code: bytes, reason: bytes, headers: Headers
+ ) -> None:
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
self.result["headers"] = headers
- def write(self, content: bytes) -> None:
- assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+ def write(self, data: bytes) -> None:
+ assert isinstance(data, bytes), "Should be bytes! " + repr(data)
if "body" not in self.result:
self.result["body"] = b""
- self.result["body"] += content
+ self.result["body"] += data
+
+ def writeSequence(self, data: Iterable[bytes]) -> None:
+ for x in data:
+ self.write(x)
+
+ def loseConnection(self) -> None:
+ self.unregisterProducer()
+ self.transport.loseConnection()
# Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
- def registerProducer( # type: ignore[override]
- self,
- producer: Union[IPullProducer, IPushProducer],
- streaming: bool,
- ) -> None:
- self._producer = producer
+ def registerProducer(self, producer: IProducer, streaming: bool) -> None:
+ # TODO This should ensure that the IProducer is an IPushProducer or
+ # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
+ # implement those, but doesn't declare it.
+ self._producer = cast(Union[IPushProducer, IPullProducer], producer)
self.producerStreaming = streaming
def _produce() -> None:
@@ -202,6 +226,16 @@ class FakeChannel:
self._producer = None
+ def stopProducing(self) -> None:
+ if self._producer is not None:
+ self._producer.stopProducing()
+
+ def pauseProducing(self) -> None:
+ raise NotImplementedError()
+
+ def resumeProducing(self) -> None:
+ raise NotImplementedError()
+
def requestDone(self, _self: Request) -> None:
self.result["done"] = True
if isinstance(_self, SynapseRequest):
@@ -281,12 +315,12 @@ class FakeSite:
self.reactor = reactor
self.experimental_cors_msc3886 = experimental_cors_msc3886
- def getResourceFor(self, request):
+ def getResourceFor(self, request: Request) -> IResource:
return self._resource
def make_request(
- reactor,
+ reactor: MemoryReactorClock,
site: Union[Site, FakeSite],
method: Union[bytes, str],
path: Union[bytes, str],
@@ -401,25 +435,29 @@ def make_request(
return channel
-@implementer(IReactorPluggableNameResolver)
+# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly
+# marking this as an implementer of the latter seems to keep mypy-zope happier.
+@implementer(IReactorPluggableNameResolver, ISynapseReactor)
class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports callFromThread.
"""
- def __init__(self):
+ def __init__(self) -> None:
self.threadpool = ThreadPool(self)
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
- self._udp = []
+ self._udp: List[udp.Port] = []
self.lookups: Dict[str, str] = {}
- self._thread_callbacks: Deque[Callable[[], None]] = deque()
+ self._thread_callbacks: Deque[Callable[..., R]] = deque()
lookups = self.lookups
@implementer(IResolverSimple)
class FakeResolver:
- def getHostByName(self, name, timeout=None):
+ def getHostByName(
+ self, name: str, timeout: Optional[Sequence[int]] = None
+ ) -> "Deferred[str]":
if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
@@ -430,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()
- def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
+ def listenUDP(
+ self,
+ port: int,
+ protocol: DatagramProtocol,
+ interface: str = "",
+ maxPacketSize: int = 8196,
+ ) -> udp.Port:
p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening()
self._udp.append(p)
return p
- def callFromThread(self, callback, *args, **kwargs):
+ def callFromThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
"""
Make the callback fire in the next reactor iteration.
"""
- cb = lambda: callback(*args, **kwargs)
+ cb = lambda: callable(*args, **kwargs)
# it's not safe to call callLater() here, so we append the callback to a
# separate queue.
self._thread_callbacks.append(cb)
- def getThreadPool(self):
- return self.threadpool
+ def callInThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
+ raise NotImplementedError()
+
+ def suggestThreadPoolSize(self, size: int) -> None:
+ raise NotImplementedError()
+
+ def getThreadPool(self) -> "threadpool.ThreadPool":
+ # Cast to match super-class.
+ return cast(threadpool.ThreadPool, self.threadpool)
- def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
+ def add_tcp_client_callback(
+ self, host: str, port: int, callback: Callable[[], None]
+ ) -> None:
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
@@ -457,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
self._tcp_callbacks[(host, port)] = callback
- def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
+ def connectTCP(
+ self,
+ host: str,
+ port: int,
+ factory: ClientFactory,
+ timeout: float = 30,
+ bindAddress: Optional[Tuple[str, int]] = None,
+ ) -> IConnector:
"""Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
@@ -470,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn
- def advance(self, amount):
+ def advance(self, amount: float) -> None:
# first advance our reactor's time, and run any "callLater" callbacks that
# makes ready
super().advance(amount)
@@ -498,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
class ThreadPool:
"""
Threadless thread pool.
+
+ See twisted.python.threadpool.ThreadPool
"""
- def __init__(self, reactor):
+ def __init__(self, reactor: IReactorTime):
self._reactor = reactor
- def start(self):
+ def start(self) -> None:
pass
- def stop(self):
+ def stop(self) -> None:
pass
- def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
- def _(res):
+ def callInThreadWithCallback(
+ self,
+ onResult: Callable[[bool, Union[Failure, R]], None],
+ function: Callable[P, R],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> "Deferred[None]":
+ def _(res: Any) -> None:
if isinstance(res, Failure):
onResult(False, res)
else:
onResult(True, res)
- d = Deferred()
+ d: "Deferred[None]" = Deferred()
d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_)
self._reactor.callLater(0, d.callback, True)
@@ -533,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
for database in server.get_datastores().databases:
pool = database._db_pool
- def runWithConnection(func, *args, **kwargs):
+ def runWithConnection(
+ func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
@@ -543,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
**kwargs,
)
- def runInteraction(interaction, *args, **kwargs):
+ def runInteraction(
+ desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
pool._runInteraction,
- interaction,
+ desc,
+ func,
*args,
**kwargs,
)
- pool.runWithConnection = runWithConnection
- pool.runInteraction = runInteraction
+ pool.runWithConnection = runWithConnection # type: ignore[assignment]
+ pool.runInteraction = runInteraction # type: ignore[assignment]
# Replace the thread pool with a threadless 'thread' pool
- pool.threadpool = ThreadPool(clock._reactor)
+ pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
pool.running = True
# We've just changed the Databases to run DB transactions on the same
@@ -571,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
@implementer(ITransport)
-@attr.s(cmp=False)
+@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -586,48 +663,50 @@ class FakeTransport:
If you want bidirectional communication, you'll need two instances.
"""
- other = attr.ib()
+ other: IProtocol
"""The Protocol object which will receive any data written to this transport.
-
- :type: twisted.internet.interfaces.IProtocol
"""
- _reactor = attr.ib()
+ _reactor: IReactorTime
"""Test reactor
-
- :type: twisted.internet.interfaces.IReactorTime
"""
- _protocol = attr.ib(default=None)
+ _protocol: Optional[IProtocol] = None
"""The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc.
"""
- _peer_address: Optional[IAddress] = attr.ib(default=None)
+ _peer_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
+ )
"""The value to be returned by getPeer"""
- _host_address: Optional[IAddress] = attr.ib(default=None)
+ _host_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
+ )
"""The value to be returned by getHost"""
disconnecting = False
disconnected = False
connected = True
- buffer = attr.ib(default=b"")
- producer = attr.ib(default=None)
- autoflush = attr.ib(default=True)
+ buffer: bytes = b""
+ producer: Optional[IPushProducer] = None
+ autoflush: bool = True
- def getPeer(self) -> Optional[IAddress]:
+ def getPeer(self) -> IAddress:
return self._peer_address
- def getHost(self) -> Optional[IAddress]:
+ def getHost(self) -> IAddress:
return self._host_address
- def loseConnection(self, reason=None):
+ def loseConnection(self) -> None:
if not self.disconnecting:
- logger.info("FakeTransport: loseConnection(%s)", reason)
+ logger.info("FakeTransport: loseConnection()")
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(reason)
+ self._protocol.connectionLost(
+ Failure(RuntimeError("FakeTransport.loseConnection()"))
+ )
# if we still have data to write, delay until that is done
if self.buffer:
@@ -638,38 +717,38 @@ class FakeTransport:
self.connected = False
self.disconnected = True
- def abortConnection(self):
+ def abortConnection(self) -> None:
logger.info("FakeTransport: abortConnection()")
if not self.disconnecting:
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(None)
+ self._protocol.connectionLost(None) # type: ignore[arg-type]
self.disconnected = True
- def pauseProducing(self):
+ def pauseProducing(self) -> None:
if not self.producer:
return
self.producer.pauseProducing()
- def resumeProducing(self):
+ def resumeProducing(self) -> None:
if not self.producer:
return
self.producer.resumeProducing()
- def unregisterProducer(self):
+ def unregisterProducer(self) -> None:
if not self.producer:
return
self.producer = None
- def registerProducer(self, producer, streaming):
+ def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
self.producer = producer
self.producerStreaming = streaming
- def _produce():
+ def _produce() -> None:
if not self.producer:
# we've been unregistered
return
@@ -681,7 +760,7 @@ class FakeTransport:
if not streaming:
self._reactor.callLater(0.0, _produce)
- def write(self, byt):
+ def write(self, byt: bytes) -> None:
if self.disconnecting:
raise Exception("Writing to disconnecting FakeTransport")
@@ -693,11 +772,11 @@ class FakeTransport:
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
- def writeSequence(self, seq):
+ def writeSequence(self, seq: Iterable[bytes]) -> None:
for x in seq:
self.write(x)
- def flush(self, maxbytes=None):
+ def flush(self, maxbytes: Optional[int] = None) -> None:
if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol
@@ -748,17 +827,17 @@ def connect_client(
class TestHomeServer(HomeServer):
- DATASTORE_CLASS = DataStore
+ DATASTORE_CLASS = DataStore # type: ignore[assignment]
def setup_test_homeserver(
- cleanup_func,
- name="test",
- config=None,
- reactor=None,
+ cleanup_func: Callable[[Callable[[], None]], None],
+ name: str = "test",
+ config: Optional[HomeServerConfig] = None,
+ reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
- **kwargs,
-):
+ **kwargs: Any,
+) -> HomeServer:
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
@@ -773,13 +852,14 @@ def setup_test_homeserver(
HomeserverTestCase.
"""
if reactor is None:
- from twisted.internet import reactor
+ from twisted.internet import reactor as _reactor
+
+ reactor = cast(ISynapseReactor, _reactor)
if config is None:
config = default_config(name, parse=True)
config.caches.resize_all_caches()
- config.ldap_enabled = False
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
@@ -830,6 +910,8 @@ def setup_test_homeserver(
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
+ import psycopg2.extensions
+
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
@@ -837,6 +919,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
@@ -865,14 +948,15 @@ def setup_test_homeserver(
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
+ database_pool = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL
- def cleanup():
+ def cleanup() -> None:
import psycopg2
+ import psycopg2.extensions
# Close all the db pools
- database._db_pool.close()
+ database_pool._db_pool.close()
dropped = False
@@ -884,6 +968,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -916,23 +1001,23 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- async def hash(p):
+ async def hash(p: str) -> str:
return hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().hash = hash
+ hs.get_auth_handler().hash = hash # type: ignore[assignment]
- async def validate_hash(p, h):
+ async def validate_hash(p: str, h: str) -> bool:
return hashlib.md5(p.encode("utf8")).hexdigest() == h
- hs.get_auth_handler().validate_hash = validate_hash
+ hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
+ for module, module_config in hs.config.modules.loaded_modules:
+ module(config=module_config, api=module_api)
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 58b399a043..3fdf5a6d52 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -14,14 +14,17 @@
import os
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
class ConsentNoticesTests(unittest.HomeserverTestCase):
-
servlets = [
sync.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -29,8 +32,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
-
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
tmpdir = self.mktemp()
os.mkdir(tmpdir)
self.consent_notice_message = "consent %(consent_uri)s"
@@ -53,15 +55,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
"room_name": "Server Notices",
}
- hs = self.setup_test_homeserver(config=config)
-
- return hs
+ return self.setup_test_homeserver(config=config)
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("bob", "abc123")
self.access_token = self.login("bob", "abc123")
- def test_get_sync_message(self):
+ def test_get_sync_message(self) -> None:
"""
When user consent server notices are enabled, a sync will cause a notice
to fire (in a room which the user is invited to). The notice contains
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index dadc6efcbf..d2bfa53eda 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -24,6 +24,8 @@ from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
+from synapse.server_notices.server_notices_sender import ServerNoticesSender
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -33,7 +35,7 @@ from tests.utils import default_config
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = default_config("test")
config.update(
@@ -57,14 +59,15 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.server_notices_sender = self.hs.get_server_notices_sender()
+ server_notices_sender = self.hs.get_server_notices_sender()
+ assert isinstance(server_notices_sender, ServerNoticesSender)
# relying on [1] is far from ideal, but the only case where
# ResourceLimitsServerNotices class needs to be isolated is this test,
# general code should never have a reason to do so ...
- self._rlsn = self.server_notices_sender._server_notices[1]
- if not isinstance(self._rlsn, ResourceLimitsServerNotices):
- raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+ rlsn = list(server_notices_sender._server_notices)[1]
+ assert isinstance(rlsn, ResourceLimitsServerNotices)
+ self._rlsn = rlsn
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(1000)
@@ -86,39 +89,43 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment]
@override_config({"hs_disabled": True})
- def test_maybe_send_server_notice_disabled_hs(self):
+ def test_maybe_send_server_notice_disabled_hs(self) -> None:
"""If the HS is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@override_config({"limit_usage_by_mau": False})
- def test_maybe_send_server_notice_to_user_flag_off(self):
+ def test_maybe_send_server_notice_to_user_flag_off(self) -> None:
"""If mau limiting is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
"""Test when user has blocked notice, but should have it removed"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
- self._rlsn._store.get_events = Mock(
+ self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
- self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
+ maybe_get_notice_room_for_user = (
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user
+ )
+ assert isinstance(maybe_get_notice_room_for_user, Mock)
+ maybe_get_notice_room_for_user.assert_called_once()
self._send_notice.assert_called_once()
- def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
"""
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
@@ -126,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
- self._rlsn._store.get_events = Mock(
+ self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
)
@@ -134,11 +141,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
+ def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None:
"""
Test when user does not have blocked notice, but should have one
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
@@ -147,11 +154,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
# Would be better to check contents, but 2 calls == set blocking event
self.assertEqual(self._send_notice.call_count, 2)
- def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
+ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None:
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
@@ -159,12 +166,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
+ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None:
"""
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
self._rlsn._store.user_last_seen_monthly_active = Mock(
@@ -175,12 +182,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
@override_config({"mau_limit_alerting": False})
- def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(
+ self,
+ ) -> None:
"""
Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
@@ -191,11 +200,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 0)
@override_config({"mau_limit_alerting": False})
- def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
+ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None:
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
@@ -207,26 +216,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 2)
@override_config({"mau_limit_alerting": False})
- def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(
+ self,
+ ) -> None:
"""
When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state.
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
)
- self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
+ self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment]
return_value=make_awaitable((True, []))
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
- self._rlsn._store.get_events = Mock(
+ self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -242,7 +253,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> JsonDict:
c = super().default_config()
c["server_notices"] = {
"system_mxid_localpart": "server",
@@ -257,20 +268,22 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
- self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager()
self.event_source = self.hs.get_event_sources()
+ server_notices_sender = self.hs.get_server_notices_sender()
+ assert isinstance(server_notices_sender, ServerNoticesSender)
+
# relying on [1] is far from ideal, but the only case where
# ResourceLimitsServerNotices class needs to be isolated is this test,
# general code should never have a reason to do so ...
- self._rlsn = self.server_notices_sender._server_notices[1]
- if not isinstance(self._rlsn, ResourceLimitsServerNotices):
- raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+ rlsn = list(server_notices_sender._server_notices)[1]
+ assert isinstance(rlsn, ResourceLimitsServerNotices)
+ self._rlsn = rlsn
self.user_id = "@user_id:test"
- def test_server_notice_only_sent_once(self):
+ def test_server_notice_only_sent_once(self) -> None:
self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.user_last_seen_monthly_active = Mock(
@@ -306,7 +319,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.assertEqual(count, 1)
- def test_no_invite_without_notice(self):
+ def test_no_invite_without_notice(self) -> None:
"""Tests that a user doesn't get invited to a server notices room without a
server notice being sent.
@@ -328,7 +341,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
m.assert_called_once_with(user_id)
- def test_invite_with_notice(self):
+ def test_invite_with_notice(self) -> None:
"""Tests that, if the MAU limit is hit, the server notices user invites each user
to a room in which it has sent a notice.
"""
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index 373707b275..b6d5c474b0 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase
class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
-
servlets = [
admin.register_servlets,
devices.register_servlets,
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 9f33afcca0..9606ecc43b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# Persist the event which should invalidate or prefill the
# `have_seen_event` cache so we don't return stale values.
persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
self.get_success(
persistence.persist_event(
event,
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index ac77aec003..71db47405e 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -26,7 +26,6 @@ from tests.unittest import HomeserverTestCase
class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
@@ -62,6 +61,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
keys and expected receipt key-values after duplicate receipts have been
removed.
"""
+
# First, undo the background update.
def drop_receipts_unique_index(txn: LoggingTransaction) -> None:
txn.execute(f"DROP INDEX IF EXISTS {index_name}")
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 3108ca3444..dbd8f3a85e 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -27,7 +27,6 @@ from tests.unittest import HomeserverTestCase
class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 1bfd11ceae..b12691a9d3 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -140,3 +140,25 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# No one ignores the user now.
self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set())
+
+ def test_ignoring_users_with_latest_stream_ids(self) -> None:
+ """Test that ignoring users updates the latest stream ID for the ignored
+ user list account data."""
+
+ def get_latest_ignore_streampos(user_id: str) -> Optional[int]:
+ return self.get_success(
+ self.store.get_latest_stream_id_for_global_account_data_by_type_for_user(
+ user_id, AccountDataTypes.IGNORED_USER_LIST
+ )
+ )
+
+ self.assertIsNone(get_latest_ignore_streampos("@user:test"))
+
+ self._update_ignore_list("@other:test", "@another:remote")
+
+ self.assertEqual(get_latest_ignore_streampos("@user:test"), 2)
+
+ # Add one user, remove one user, and leave one user.
+ self._update_ignore_list("@foo:test", "@another:remote")
+
+ self.assertEqual(get_latest_ignore_streampos("@user:test"), 3)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index d570684c99..7de109966d 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -43,8 +43,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
self.requester = create_requester(self.user)
- info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
- self.room_id = info["room_id"]
+ self.room_id, _, _ = self.get_success(
+ self.room_creator.create_room(self.requester, {})
+ )
def run_background_update(self) -> None:
"""Re run the background update to clean up the extremities."""
@@ -275,10 +276,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
self.requester = create_requester(self.user)
- info, _ = self.get_success(
+ self.room_id, _, _ = self.get_success(
self.room_creator.create_room(self.requester, {"visibility": "public"})
)
- self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 7f7f4ef892..cd0079871c 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -656,7 +656,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c070278db8..e39b63edac 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
"""
persist_events_store = self.hs.get_datastores().persist_events
+ assert persist_events_store is not None
for e in events:
e.internal_metadata.stream_ordering = self._next_stream_ordering
@@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events
# tables.
+ assert persist_events_store is not None
persist_events_store._store_event_txn(
txn,
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
@@ -415,7 +417,6 @@ class EventChainStoreTestCase(HomeserverTestCase):
def fetch_chains(
self, events: List[EventBase]
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
-
# Fetch the map from event ID -> (chain ID, sequence number)
rows = self.get_success(
self.store.db_pool.simple_select_many_batch(
@@ -490,7 +491,6 @@ class LinkMapTestCase(unittest.TestCase):
class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
@@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id)
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_handler.create_event(
self.requester,
{
@@ -535,14 +535,17 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids,
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
)
)
- state1 = set(self.get_success(context.get_current_state_ids()).values())
+ state_ids1 = self.get_success(context.get_current_state_ids())
+ assert state_ids1 is not None
+ state1 = set(state_ids1.values())
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_handler.create_event(
self.requester,
{
@@ -555,12 +558,15 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids,
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
)
)
- state2 = set(self.get_success(context.get_current_state_ids()).values())
+ state_ids2 = self.get_success(context.get_current_state_ids())
+ assert state_ids2 is not None
+ state2 = set(state_ids2.values())
# Delete the chain cover info.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 7fd3e01364..3e1984c15c 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ persist_events = hs.get_datastores().persist_events
+ assert persist_events is not None
+ self.persist_events = persist_events
def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local"
@@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)
@@ -669,7 +672,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
complete_event_dict_map: Dict[str, JsonDict] = {}
stream_ordering = 0
- for (event_id, prev_event_ids) in event_graph.items():
+ for event_id, prev_event_ids in event_graph.items():
depth = depth_map[event_id]
complete_event_dict_map[event_id] = {
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a91411168c..6897addbd3 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -33,8 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
events = [(3, 2), (6, 2), (4, 6)]
for event_count, extrems in events:
- info, _ = self.get_success(room_creator.create_room(requester, {}))
- room_id = info["room_id"]
+ room_id, _, _ = self.get_success(room_creator.create_room(requester, {}))
last_event = None
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 76c06a9d1e..aa19c3bd30 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -774,7 +774,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
- for (stream_ordering, ts) in (
+ for stream_ordering, ts in (
(3, 110),
(4, 120),
(5, 120),
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 05661a537d..e67dd0589d 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
- self._persistence = self.hs.get_storage_controllers().persistence
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persistence = persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main
@@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
- self._persistence = self.hs.get_storage_controllers().persistence
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persistence = persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self) -> None:
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index aa4b5bd3b1..ba68171ad7 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -16,8 +16,6 @@ import signedjson.key
import signedjson.types
import unpaddedbase64
-from twisted.internet.defer import Deferred
-
from synapse.storage.keys import FetchKeyResult
import tests.unittest
@@ -44,20 +42,26 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2"
- d = store.store_server_verify_keys(
- "from_server",
- 10,
- [
- ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
- ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
- ],
+ self.get_success(
+ store.store_server_verify_keys(
+ "from_server",
+ 10,
+ [
+ ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
)
- self.get_success(d)
- d = store.get_server_verify_keys(
- [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
+ res = self.get_success(
+ store.get_server_verify_keys(
+ [
+ ("server1", key_id_1),
+ ("server1", key_id_2),
+ ("server1", "ed25519:key3"),
+ ]
+ )
)
- res = self.get_success(d)
self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)]
@@ -82,18 +86,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
- d = store.store_server_verify_keys(
- "from_server",
- 0,
- [
- ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
- ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
- ],
+ self.get_success(
+ store.store_server_verify_keys(
+ "from_server",
+ 0,
+ [
+ ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
)
- self.get_success(d)
- d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- res = self.get_success(d)
+ res = self.get_success(
+ store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ )
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
@@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
- res = store.get_server_verify_keys([("srv1", key_id_1)])
- if isinstance(res, Deferred):
- res = self.successResultOf(res)
+ res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
@@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.get_success(d)
- d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- res = self.get_success(d)
+ res = self.get_success(
+ store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ )
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 010cc74c31..857e2caf2e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase
class PurgeTests(HomeserverTestCase):
-
user_id = "@red:server"
servlets = [room.register_servlets]
@@ -112,7 +111,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id, "m.room.create", ""
)
)
- self.assertIsNotNone(create_event)
+ assert create_event is not None
# Purge everything before this topological token
self.get_success(
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index d8d84152dc..1b52eef23f 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase):
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
- self.persist_event_storage_controller = (
- self.hs.get_storage_controllers().persistence
- )
+ persist_event_storage_controller = self.hs.get_storage_controllers().persistence
+ assert persist_event_storage_controller is not None
+ self.persist_event_storage_controller = persist_event_storage_controller
# Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID)
@@ -50,12 +50,14 @@ class ReceiptTestCase(HomeserverTestCase):
self.otherRequester = create_requester(self.otherUser)
# Create a test room
- info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
- self.room_id1 = info["room_id"]
+ self.room_id1, _, _ = self.get_success(
+ self.room_creator.create_room(self.ourRequester, {})
+ )
# Create a second test room
- info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
- self.room_id2 = info["room_id"]
+ self.room_id2, _, _ = self.get_success(
+ self.room_creator.create_room(self.ourRequester, {})
+ )
# Join the second user to the first room
memberEvent, memberEventContext = self.get_success(
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
- event_1, context_1 = self.get_success(
+ event_1, unpersisted_context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
+ context_1 = self.get_success(unpersisted_context_1.persist(event_1))
+
self.get_success(self._persistence.persist_event(event_1, context_1))
- event_2, context_2 = self.get_success(
+ event_2, unpersisted_context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
+
+ context_2 = self.get_success(unpersisted_context_2.persist(event_2))
self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- redaction_event, context = self.get_success(
+ redaction_event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(redaction_event))
+
self.get_success(self._persistence.persist_event(redaction_event, context))
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 14d872514d..f183c38477 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase):
"content": {"msgtype": "m.text", "body": 2},
"room_id": room_id,
"sender": user_id,
- "depth": prev_event.depth + 1,
"prev_events": prev_event_ids,
"origin_server_ts": self.clock.time_msec(),
}
@@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_state_map,
for_verification=False,
),
- depth=event_dict["depth"],
+ depth=prev_event.depth + 1,
)
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 8794401823..f4c4661aaf 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -27,7 +27,6 @@ from tests.test_utils import event_injection
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
register_servlets_for_client_rest_resource,
@@ -35,7 +34,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # type: ignore[override]
-
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastores().main
@@ -48,7 +46,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.u_charlie = UserID.from_string("@charlie:elsewhere")
def test_one_member(self) -> None:
-
# Alice creates the room, and is automatically joined
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..62aed6af0a 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
@@ -240,7 +242,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -257,7 +259,7 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict,
)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -270,7 +272,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -287,7 +289,7 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict,
)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -307,7 +309,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -325,7 +327,7 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict,
)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -339,7 +341,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -390,7 +392,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -402,7 +404,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertDictEqual({}, state_dict)
room_id = self.room.to_string()
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -415,7 +417,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -426,7 +428,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -445,7 +447,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -457,7 +459,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -471,7 +473,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -483,7 +485,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -494,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+ def test_batched_state_group_storing(self) -> None:
+ creation_event = self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Create, "", {}
+ )
+ state_to_event = self.get_success(
+ self.storage.state.get_state_groups(
+ self.room.to_string(), [creation_event.event_id]
+ )
+ )
+ current_state_group = list(state_to_event.keys())[0]
+
+ # create some unpersisted events and event contexts to store against room
+ events_and_context = []
+ builder = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Name,
+ "sender": self.u_alice.to_string(),
+ "state_key": "",
+ "room_id": self.room.to_string(),
+ "content": {"name": "first rename of room"},
+ },
+ )
+
+ event1, unpersisted_context1 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+ events_and_context.append((event1, unpersisted_context1))
+
+ builder2 = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.JoinRules,
+ "sender": self.u_alice.to_string(),
+ "state_key": "",
+ "room_id": self.room.to_string(),
+ "content": {"join_rule": "private"},
+ },
+ )
+
+ event2, unpersisted_context2 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder2)
+ )
+ events_and_context.append((event2, unpersisted_context2))
+
+ builder3 = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Message,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room.to_string(),
+ "content": {"body": "hello from event 3", "msgtype": "m.text"},
+ },
+ )
+
+ event3, unpersisted_context3 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder3)
+ )
+ events_and_context.append((event3, unpersisted_context3))
+
+ builder4 = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.JoinRules,
+ "sender": self.u_alice.to_string(),
+ "state_key": "",
+ "room_id": self.room.to_string(),
+ "content": {"join_rule": "public"},
+ },
+ )
+
+ event4, unpersisted_context4 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder4)
+ )
+ events_and_context.append((event4, unpersisted_context4))
+
+ processed_events_and_context = self.get_success(
+ self.hs.get_datastores().state.store_state_deltas_for_batched(
+ events_and_context, self.room.to_string(), current_state_group
+ )
+ )
+
+ # check that only state events are in state_groups, and all state events are in state_groups
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups",
+ keyvalues=None,
+ retcols=("event_id",),
+ )
+ )
+
+ events = []
+ for result in res:
+ self.assertNotIn(event3.event_id, result)
+ events.append(result.get("event_id"))
+
+ for event, _ in processed_events_and_context:
+ if event.is_state():
+ self.assertIn(event.event_id, events)
+
+ # check that each unique state has state group in state_groups_state and that the
+ # type/state key is correct, and check that each state event's state group
+ # has an entry and prev event in state_group_edges
+ for event, context in processed_events_and_context:
+ if event.is_state():
+ state = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups_state",
+ keyvalues={"state_group": context.state_group_after_event},
+ retcols=("type", "state_key"),
+ )
+ )
+ self.assertEqual(event.type, state[0].get("type"))
+ self.assertEqual(event.state_key, state[0].get("state_key"))
+
+ groups = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_group_edges",
+ keyvalues={"state_group": str(context.state_group_after_event)},
+ retcols=("*",),
+ )
+ )
+ self.assertEqual(
+ context.state_group_before_event, groups[0].get("prev_state_group")
+ )
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index bc090ebce0..05dc4f64b8 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -16,7 +16,7 @@ from typing import List
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.filtering import Filter
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase):
room_id=self.room_id,
from_key=self.from_token.room_key,
to_key=None,
- direction="f",
+ direction=Direction.FORWARDS,
limit=10,
event_filter=Filter(self.hs, filter),
)
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
index ba53c22818..19da8a9b09 100644
--- a/tests/storage/test_unsafe_locale.py
+++ b/tests/storage/test_unsafe_locale.py
@@ -14,6 +14,7 @@
from unittest.mock import MagicMock, patch
from synapse.storage.database import make_conn
+from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IncorrectDatabaseSetup
from tests.unittest import HomeserverTestCase
@@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase):
def test_safe_locale(self) -> None:
database = self.hs.get_datastores().databases[0]
+ assert isinstance(database.engine, PostgresEngine)
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
with db_conn.cursor() as txn:
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index f1ca523d23..8c72aa1722 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -25,6 +25,11 @@ from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.background_updates import _BackgroundUpdateHandler
+from synapse.storage.databases.main import user_directory
+from synapse.storage.databases.main.user_directory import (
+ _parse_words_with_icu,
+ _parse_words_with_regex,
+)
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock
@@ -42,7 +47,7 @@ ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
# The localpart isn't 'Bela' on purpose so we can test looking up display names.
-BELA = "@somenickname:a"
+BELA = "@somenickname:example.org"
class GetUserDirectoryTables:
@@ -423,6 +428,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
class UserDirectoryStoreTestCase(HomeserverTestCase):
+ use_icu = False
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
@@ -434,6 +441,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
+ self._restore_use_icu = user_directory.USE_ICU
+ user_directory.USE_ICU = self.use_icu
+
+ def tearDown(self) -> None:
+ user_directory.USE_ICU = self._restore_use_icu
+
def test_search_user_dir(self) -> None:
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
@@ -478,6 +491,159 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
{"user_id": BELA, "display_name": "Bela", "avatar_url": None},
)
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_start_of_user_id(self) -> None:
+ """Tests that a user can look up another user by searching for the start
+ of their user ID.
+ """
+ r = self.get_success(self.store.search_user_dir(ALICE, "somenickname:exa", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_ascii_case_insensitivity(self) -> None:
+ """Tests that a user can look up another user by searching for their name in a
+ different case.
+ """
+ CHARLIE = "@someuser:example.org"
+ self.get_success(
+ self.store.update_profile_in_user_dir(CHARLIE, "Charlie", None)
+ )
+
+ r = self.get_success(self.store.search_user_dir(ALICE, "cHARLIE", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": CHARLIE, "display_name": "Charlie", "avatar_url": None},
+ )
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_unicode_case_insensitivity(self) -> None:
+ """Tests that a user can look up another user by searching for their name in a
+ different case.
+ """
+ IVAN = "@someuser:example.org"
+ self.get_success(self.store.update_profile_in_user_dir(IVAN, "Иван", None))
+
+ r = self.get_success(self.store.search_user_dir(ALICE, "иВАН", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": IVAN, "display_name": "Иван", "avatar_url": None},
+ )
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_dotted_dotless_i_case_insensitivity(self) -> None:
+ """Tests that a user can look up another user by searching for their name in a
+ different case, when their name contains dotted or dotless "i"s.
+
+ Some languages have dotted and dotless versions of "i", which are considered to
+ be different letters: i <-> İ, ı <-> I. To make things difficult, they reuse the
+ ASCII "i" and "I" code points, despite having different lowercase / uppercase
+ forms.
+ """
+ USER = "@someuser:example.org"
+
+ expected_matches = [
+ # (search_term, display_name)
+ # A search for "i" should match "İ".
+ ("iiiii", "İİİİİ"),
+ # A search for "I" should match "ı".
+ ("IIIII", "ııııı"),
+ # A search for "ı" should match "I".
+ ("ııııı", "IIIII"),
+ # A search for "İ" should match "i".
+ ("İİİİİ", "iiiii"),
+ ]
+
+ for search_term, display_name in expected_matches:
+ self.get_success(
+ self.store.update_profile_in_user_dir(USER, display_name, None)
+ )
+
+ r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(
+ 1,
+ len(r["results"]),
+ f"searching for {search_term!r} did not match {display_name!r}",
+ )
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": USER, "display_name": display_name, "avatar_url": None},
+ )
+
+ # We don't test for negative matches, to allow implementations that consider all
+ # the i variants to be the same.
+
+ test_search_user_dir_dotted_dotless_i_case_insensitivity.skip = "not supported" # type: ignore
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_unicode_normalization(self) -> None:
+ """Tests that a user can look up another user by searching for their name with
+ either composed or decomposed accents.
+ """
+ AMELIE = "@someuser:example.org"
+
+ expected_matches = [
+ # (search_term, display_name)
+ ("Ame\u0301lie", "Amélie"),
+ ("Amélie", "Ame\u0301lie"),
+ ]
+
+ for search_term, display_name in expected_matches:
+ self.get_success(
+ self.store.update_profile_in_user_dir(AMELIE, display_name, None)
+ )
+
+ r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(
+ 1,
+ len(r["results"]),
+ f"searching for {search_term!r} did not match {display_name!r}",
+ )
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": AMELIE, "display_name": display_name, "avatar_url": None},
+ )
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_accent_insensitivity(self) -> None:
+ """Tests that a user can look up another user by searching for their name
+ without any accents.
+ """
+ AMELIE = "@someuser:example.org"
+ self.get_success(self.store.update_profile_in_user_dir(AMELIE, "Amélie", None))
+
+ r = self.get_success(self.store.search_user_dir(ALICE, "amelie", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": AMELIE, "display_name": "Amélie", "avatar_url": None},
+ )
+
+ # It may be desirable for "é"s in search terms to not match plain "e"s and we
+ # really don't want "é"s in search terms to match "e"s with different accents.
+ # But we don't test for this to allow implementations that consider all
+ # "e"-lookalikes to be the same.
+
+ test_search_user_dir_accent_insensitivity.skip = "not supported yet" # type: ignore
+
+
+class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase):
+ use_icu = True
+
+ if not icu:
+ skip = "Requires PyICU"
+
class UserDirectoryICUTestCase(HomeserverTestCase):
if not icu:
@@ -513,3 +679,33 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
r["results"][0],
{"user_id": ALICE, "display_name": display_name, "avatar_url": None},
)
+
+ def test_icu_word_boundary_punctuation(self) -> None:
+ """
+ Tests the behaviour of punctuation with the ICU tokeniser.
+
+ Seems to depend on underlying version of ICU.
+ """
+
+ # Note: either tokenisation is fine, because Postgres actually splits
+ # words itself afterwards.
+ self.assertIn(
+ _parse_words_with_icu("lazy'fox jumped:over the.dog"),
+ (
+ # ICU 66 on Ubuntu 20.04
+ ["lazy'fox", "jumped", "over", "the", "dog"],
+ # ICU 70 on Ubuntu 22.04
+ ["lazy'fox", "jumped:over", "the.dog"],
+ # pyicu 2.10.2 on Alpine edge / macOS
+ ["lazy'fox", "jumped", "over", "the.dog"],
+ ),
+ )
+
+ def test_regex_word_boundary_punctuation(self) -> None:
+ """
+ Tests the behaviour of punctuation with the non-ICU tokeniser
+ """
+ self.assertEqual(
+ _parse_words_with_regex("lazy'fox jumped:over the.dog"),
+ ["lazy", "fox", "jumped", "over", "the", "dog"],
+ )
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 31546ea52b..a248f1d277 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -21,10 +21,10 @@ from . import unittest
class DistributorTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.dist = Distributor()
- def test_signal_dispatch(self):
+ def test_signal_dispatch(self) -> None:
self.dist.declare("alert")
observer = Mock()
@@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase):
self.dist.fire("alert", 1, 2, 3)
observer.assert_called_with(1, 2, 3)
- def test_signal_catch(self):
+ def test_signal_catch(self) -> None:
self.dist.declare("alarm")
observers = [Mock() for i in (1, 2)]
@@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase):
self.assertEqual(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
- def test_signal_prereg(self):
+ def test_signal_prereg(self) -> None:
observer = Mock()
self.dist.observe("flare", observer)
@@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase):
observer.assert_called_with(4, 5)
- def test_signal_undeclared(self):
- def code():
+ def test_signal_undeclared(self) -> None:
+ def code() -> None:
self.dist.fire("notification")
self.assertRaises(KeyError, code)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 0a7937f1cc..2860564afc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result
class _StubEventSourceStore:
"""A stub implementation of the EventSourceStore"""
- def __init__(self):
+ def __init__(self) -> None:
self._store: Dict[str, EventBase] = {}
- def add_event(self, event: EventBase):
+ def add_event(self, event: EventBase) -> None:
self._store[event.event_id] = event
- def add_events(self, events: Iterable[EventBase]):
+ def add_events(self, events: Iterable[EventBase]) -> None:
for event in events:
self._store[event.event_id] = event
@@ -59,7 +59,7 @@ class _StubEventSourceStore:
class EventAuthTestCase(unittest.TestCase):
- def test_rejected_auth_events(self):
+ def test_rejected_auth_events(self) -> None:
"""
Events that refer to rejected events in their auth events are rejected
"""
@@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase):
)
)
- def test_create_event_with_prev_events(self):
+ def test_create_event_with_prev_events(self) -> None:
"""A create event with prev_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_duplicate_auth_events(self):
+ def test_duplicate_auth_events(self) -> None:
"""Events with duplicate auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event2)
)
- def test_unexpected_auth_events(self):
+ def test_unexpected_auth_events(self) -> None:
"""Events with excess auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_random_users_cannot_send_state_before_first_pl(self):
+ def test_random_users_cannot_send_state_before_first_pl(self) -> None:
"""
Check that, before the first PL lands, the creator is the only user
that can send a state event.
@@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_state_default_level(self):
+ def test_state_default_level(self) -> None:
"""
Check that users above the state_default level can send state and
those below cannot
@@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_alias_event(self):
+ def test_alias_event(self) -> None:
"""Alias events have special behavior up through room version 6."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_msc2432_alias_event(self):
+ def test_msc2432_alias_event(self) -> None:
"""After MSC2432, alias events have no special behavior."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase):
)
@parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
- def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
+ def test_notifications(
+ self, room_version: RoomVersion, allow_modification: bool
+ ) -> None:
"""
Notifications power levels get checked due to MSC2209.
"""
@@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
with self.assertRaises(AuthError):
event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
- def test_join_rules_public(self):
+ def test_join_rules_public(self) -> None:
"""
Test joining a public room.
"""
@@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events.values(),
)
- def test_join_rules_invite(self):
+ def test_join_rules_invite(self) -> None:
"""
Test joining an invite only room.
"""
@@ -835,7 +837,7 @@ def _power_levels_event(
)
-def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase:
data = {
"room_id": TEST_ROOM_ID,
**_maybe_get_event_id_dict_for_room_version(room_version),
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 80e5c590d8..46d2f99eac 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -12,53 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, List, Optional, Union
from unittest.mock import Mock
-from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import FederationError
-from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json
+from synapse.handlers.device import DeviceListUpdater
+from synapse.http.types import QueryParams
from synapse.logging.context import LoggingContext
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
-from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase):
- def setUp(self):
-
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock()
- self.reactor = ThreadedMemoryReactorClock()
- self.hs_clock = Clock(self.reactor)
- self.homeserver = setup_test_homeserver(
- self.addCleanup,
- federation_http_client=self.http_client,
- clock=self.hs_clock,
- reactor=self.reactor,
- )
+ return self.setup_test_homeserver(federation_http_client=self.http_client)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
user_id = UserID("us", "test")
our_user = create_requester(user_id)
- room_creator = self.homeserver.get_room_creation_handler()
+ room_creator = self.hs.get_room_creation_handler()
self.room_id = self.get_success(
room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
- )[0]["room_id"]
+ )[0]
- self.store = self.homeserver.get_datastores().main
+ self.store = self.hs.get_datastores().main
# Figure out what the most recent event is
most_recent = self.get_success(
- self.homeserver.get_datastores().main.get_latest_event_ids_in_room(
- self.room_id
- )
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)[0]
join_event = make_event_from_dict(
@@ -78,17 +73,23 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- self.handler = self.homeserver.get_federation_handler()
- federation_event_handler = self.homeserver.get_federation_event_handler()
+ self.handler = self.hs.get_federation_handler()
+ federation_event_handler = self.hs.get_federation_event_handler()
- async def _check_event_auth(origin, event, context):
+ async def _check_event_auth(
+ origin: Optional[str], event: EventBase, context: EventContext
+ ) -> None:
pass
- federation_event_handler._check_event_auth = _check_event_auth
- self.client = self.homeserver.get_federation_client()
- self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
- lambda dest, pdus, **k: succeed(pdus)
- )
+ federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment]
+ self.client = self.hs.get_federation_client()
+
+ async def _check_sigs_and_hash_for_pulled_events_and_fetch(
+ dest: str, pdus: Collection[EventBase], room_version: RoomVersion
+ ) -> List[EventBase]:
+ return list(pdus)
+
+ self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
# Send the join, it should return None (which is not an error)
self.assertEqual(
@@ -104,16 +105,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
"$join:test.serv",
)
- def test_cant_hide_direct_ancestors(self):
+ def test_cant_hide_direct_ancestors(self) -> None:
"""
If you send a message, you must be able to provide the direct
prev_events that said event references.
"""
- async def post_json(destination, path, data, headers=None, timeout=0):
+ async def post_json(
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryParams] = None,
+ ) -> Union[JsonDict, list]:
# If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []}
+ return {}
self.http_client.post_json = post_json
@@ -138,7 +148,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- federation_event_handler = self.homeserver.get_federation_event_handler()
+ federation_event_handler = self.hs.get_federation_event_handler()
with LoggingContext("test-context"):
failure = self.get_failure(
federation_event_handler.on_receive_pdu("test.serv", lying_event),
@@ -158,7 +168,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
self.assertEqual(extrem[0], "$join:test.serv")
- def test_retry_device_list_resync(self):
+ def test_retry_device_list_resync(self) -> None:
"""Tests that device lists are marked as stale if they couldn't be synced, and
that stale device lists are retried periodically.
"""
@@ -171,24 +181,27 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# When this function is called, increment the number of resync attempts (only if
# we're querying devices for the right user ID), then raise a
# NotRetryingDestination error to fail the resync gracefully.
- def query_user_devices(destination, user_id):
+ def query_user_devices(
+ destination: str, user_id: str, timeout: int = 30000
+ ) -> JsonDict:
if user_id == remote_user_id:
self.resync_attempts += 1
raise NotRetryingDestination(0, 0, destination)
# Register the mock on the federation client.
- federation_client = self.homeserver.get_federation_client()
- federation_client.query_user_devices = Mock(side_effect=query_user_devices)
+ federation_client = self.hs.get_federation_client()
+ federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment]
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
- store = self.homeserver.get_datastores().main
+ store = self.hs.get_datastores().main
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
- device_list_updater = self.homeserver.get_device_handler().device_list_updater
+ device_list_updater = self.hs.get_device_handler().device_list_updater
+ assert isinstance(device_list_updater, DeviceListUpdater)
self.get_success(
device_list_updater.incoming_device_list_update(
origin=remote_origin,
@@ -218,7 +231,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.reactor.advance(30)
self.assertEqual(self.resync_attempts, 2)
- def test_cross_signing_keys_retry(self):
+ def test_cross_signing_keys_retry(self) -> None:
"""Tests that resyncing a device list correctly processes cross-signing keys from
the remote server.
"""
@@ -227,8 +240,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
# Register mock device list retrieval on the federation client.
- federation_client = self.homeserver.get_federation_client()
- federation_client.query_user_devices = Mock(
+ federation_client = self.hs.get_federation_client()
+ federation_client.query_user_devices = Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
@@ -252,7 +265,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Resync the device list.
- device_handler = self.homeserver.get_device_handler()
+ device_handler = self.hs.get_device_handler()
self.get_success(
device_handler.device_list_updater.user_device_resync(remote_user_id),
)
@@ -261,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
keys = self.get_success(
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
)
- self.assertTrue(remote_user_id in keys)
+ self.assertIn(remote_user_id, keys)
+ key = keys[remote_user_id]
+ assert key is not None
# Check that the master key is the one returned by the mock.
- master_key = keys[remote_user_id]["master"]
+ master_key = key["master"]
self.assertEqual(len(master_key["keys"]), 1)
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
self.assertTrue(remote_master_key in master_key["keys"].values())
# Check that the self-signing key is the one returned by the mock.
- self_signing_key = keys[remote_user_id]["self_signing"]
+ self_signing_key = key["self_signing"]
self.assertEqual(len(self_signing_key["keys"]), 1)
self.assertTrue(
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
@@ -279,7 +294,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
class StripUnsignedFromEventsTestCase(unittest.TestCase):
- def test_strip_unauthorized_unsigned_values(self):
+ def test_strip_unauthorized_unsigned_values(self) -> None:
event1 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
@@ -296,7 +311,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
# Make sure unauthorized fields are stripped from unsigned
self.assertNotIn("more warez", filtered_event.unsigned)
- def test_strip_event_maintains_allowed_fields(self):
+ def test_strip_event_maintains_allowed_fields(self) -> None:
event2 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
@@ -323,7 +338,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
self.assertIn("invite_room_state", filtered_event2.unsigned)
self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
- def test_strip_event_removes_fields_based_on_event_type(self):
+ def test_strip_event_removes_fields_based_on_event_type(self) -> None:
event3 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
diff --git a/tests/test_mau.py b/tests/test_mau.py
index f14fcb7db9..ff21098a59 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -14,12 +14,17 @@
"""Tests REST events for /rooms paths."""
-from typing import List
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
from synapse.rest.client import register, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
@@ -27,10 +32,9 @@ from tests.utils import default_config
class TestMauLimit(unittest.HomeserverTestCase):
-
servlets = [register.register_servlets, sync.register_servlets]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = default_config("test")
config.update(
@@ -53,10 +57,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_simple_deny_mau(self):
+ def test_simple_deny_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -75,7 +81,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_as_ignores_mau(self):
+ def test_as_ignores_mau(self) -> None:
"""Test that application services can still create users when the MAU
limit has been reached. This only works when application service
user ip tracking is disabled.
@@ -113,7 +119,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.create_user("as_kermit4", token=as_token, appservice=True)
- def test_allowed_after_a_month_mau(self):
+ def test_allowed_after_a_month_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -132,7 +138,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.do_sync_for_user(token3)
@override_config({"mau_trial_days": 1})
- def test_trial_delay(self):
+ def test_trial_delay(self) -> None:
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -165,7 +171,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@override_config({"mau_trial_days": 1})
- def test_trial_users_cant_come_back(self):
+ def test_trial_users_cant_come_back(self) -> None:
self.hs.config.server.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -216,7 +222,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# max_mau_value should not matter
{"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
)
- def test_tracked_but_not_limited(self):
+ def test_tracked_but_not_limited(self) -> None:
# Simply being able to create 2 users indicates that the
# limit was not reached.
token1 = self.create_user("kermit1")
@@ -236,10 +242,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
"mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2},
}
)
- def test_as_trial_days(self):
+ def test_as_trial_days(self) -> None:
user_tokens: List[str] = []
- def advance_time_and_sync():
+ def advance_time_and_sync() -> None:
self.reactor.advance(24 * 60 * 61)
for token in user_tokens:
self.do_sync_for_user(token)
@@ -300,7 +306,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
},
)
- def create_user(self, localpart, token=None, appservice=False):
+ def create_user(
+ self, localpart: str, token: Optional[str] = None, appservice: bool = False
+ ) -> str:
request_data = {
"username": localpart,
"password": "monkey",
@@ -326,7 +334,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token
- def do_sync_for_user(self, token):
+ def do_sync_for_user(self, token: str) -> None:
channel = self.make_request("GET", "/sync", access_token=token)
if channel.code != 200:
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index cc1a98f1c4..3f899b0d91 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -33,7 +33,7 @@ class PhoneHomeStatsTestCase(HomeserverTestCase):
If time doesn't move, don't error out.
"""
past_stats = [
- (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))
+ (int(self.hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
]
stats: JsonDict = {}
self.get_success(phone_stats_home(self.hs, stats, past_stats))
diff --git a/tests/test_rust.py b/tests/test_rust.py
index 55d8b6b28c..67443b6280 100644
--- a/tests/test_rust.py
+++ b/tests/test_rust.py
@@ -6,6 +6,6 @@ from tests import unittest
class RustTestCase(unittest.TestCase):
"""Basic tests to ensure that we can call into Rust code."""
- def test_basic(self):
+ def test_basic(self) -> None:
result = sum_as_string(1, 2)
self.assertEqual("3", result)
diff --git a/tests/test_state.py b/tests/test_state.py
index 504530b49a..b20a26e1ff 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -11,7 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Collection, Dict, List, Optional, cast
+from typing import (
+ Any,
+ Collection,
+ Dict,
+ Generator,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
from unittest.mock import Mock
from twisted.internet import defer
@@ -19,9 +31,11 @@ from twisted.internet import defer
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
+from synapse.types import MutableStateMap, StateMap
+from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.macaroons import MacaroonGenerator
@@ -33,14 +47,14 @@ _next_event_id = 1000
def create_event(
- name=None,
- type=None,
- state_key=None,
- depth=2,
- event_id=None,
- prev_events: Optional[List[str]] = None,
- **kwargs,
-):
+ name: Optional[str] = None,
+ type: Optional[str] = None,
+ state_key: Optional[str] = None,
+ depth: int = 2,
+ event_id: Optional[str] = None,
+ prev_events: Optional[List[Tuple[str, dict]]] = None,
+ **kwargs: Any,
+) -> EventBase:
global _next_event_id
if not event_id:
@@ -67,21 +81,21 @@ def create_event(
d.update(kwargs)
- event = make_event_from_dict(d)
-
- return event
+ return make_event_from_dict(d)
class _DummyStore:
- def __init__(self):
- self._event_to_state_group = {}
- self._group_to_state = {}
+ def __init__(self) -> None:
+ self._event_to_state_group: Dict[str, int] = {}
+ self._group_to_state: Dict[int, MutableStateMap[str]] = {}
- self._event_id_to_event = {}
+ self._event_id_to_event: Dict[str, EventBase] = {}
self._next_group = 1
- async def get_state_groups_ids(self, room_id, event_ids):
+ async def get_state_groups_ids(
+ self, room_id: str, event_ids: Collection[str]
+ ) -> Dict[int, MutableStateMap[str]]:
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
@@ -90,16 +104,25 @@ class _DummyStore:
return groups
- async def get_state_ids_for_group(self, state_group, state_filter=None):
+ async def get_state_ids_for_group(
+ self, state_group: int, state_filter: Optional[StateFilter] = None
+ ) -> MutableStateMap[str]:
return self._group_to_state[state_group]
async def store_state_group(
- self, event_id, room_id, prev_group, delta_ids, current_state_ids
- ):
+ self,
+ event_id: str,
+ room_id: str,
+ prev_group: Optional[int],
+ delta_ids: Optional[StateMap[str]],
+ current_state_ids: Optional[StateMap[str]],
+ ) -> int:
state_group = self._next_group
self._next_group += 1
if current_state_ids is None:
+ assert prev_group is not None
+ assert delta_ids is not None
current_state_ids = dict(self._group_to_state[prev_group])
current_state_ids.update(delta_ids)
@@ -107,7 +130,9 @@ class _DummyStore:
return state_group
- async def get_events(self, event_ids, **kwargs):
+ async def get_events(
+ self, event_ids: Collection[str], **kwargs: Any
+ ) -> Dict[str, EventBase]:
return {
e_id: self._event_id_to_event[e_id]
for e_id in event_ids
@@ -119,31 +144,36 @@ class _DummyStore:
) -> Dict[str, bool]:
return {e: False for e in event_ids}
- async def get_state_group_delta(self, name):
+ async def get_state_group_delta(
+ self, name: str
+ ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
return None, None
- def register_events(self, events):
+ def register_events(self, events: Iterable[EventBase]) -> None:
for e in events:
self._event_id_to_event[e.event_id] = e
- def register_event_context(self, event, context):
+ def register_event_context(self, event: EventBase, context: EventContext) -> None:
+ assert context.state_group is not None
self._event_to_state_group[event.event_id] = context.state_group
- def register_event_id_state_group(self, event_id, state_group):
+ def register_event_id_state_group(self, event_id: str, state_group: int) -> None:
self._event_to_state_group[event_id] = state_group
- async def get_room_version_id(self, room_id):
+ async def get_room_version_id(self, room_id: str) -> str:
return RoomVersions.V1.identifier
async def get_state_group_for_events(
- self, event_ids, await_full_state: bool = True
- ):
+ self, event_ids: Collection[str], await_full_state: bool = True
+ ) -> Dict[str, int]:
res = {}
for event in event_ids:
res[event] = self._event_to_state_group[event]
return res
- async def get_state_for_groups(self, groups):
+ async def get_state_for_groups(
+ self, groups: Collection[int]
+ ) -> Dict[int, MutableStateMap[str]]:
res = {}
for group in groups:
state = self._group_to_state[group]
@@ -152,21 +182,21 @@ class _DummyStore:
class DictObj(dict):
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any) -> None:
super().__init__(kwargs)
self.__dict__ = self
class Graph:
- def __init__(self, nodes, edges):
- events = {}
- clobbered = set(events.keys())
+ def __init__(self, nodes: Dict[str, DictObj], edges: Dict[str, List[str]]):
+ events: Dict[str, EventBase] = {}
+ clobbered: Set[str] = set()
for event_id, fields in nodes.items():
refs = edges.get(event_id)
if refs:
clobbered.difference_update(refs)
- prev_events = [(r, {}) for r in refs]
+ prev_events: List[Tuple[str, dict]] = [(r, {}) for r in refs]
else:
prev_events = []
@@ -177,15 +207,12 @@ class Graph:
self._leaves = clobbered
self._events = sorted(events.values(), key=lambda e: e.depth)
- def walk(self):
+ def walk(self) -> Iterator[EventBase]:
return iter(self._events)
- def get_leaves(self):
- return (self._events[i] for i in self._leaves)
-
class StateTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.dummy_store = _DummyStore()
storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
hs = Mock(
@@ -220,7 +247,7 @@ class StateTestCase(unittest.TestCase):
self.event_id = 0
@defer.inlineCallbacks
- def test_branch_no_conflict(self):
+ def test_branch_no_conflict(self) -> Generator[defer.Deferred, Any, None]:
graph = Graph(
nodes={
"START": DictObj(
@@ -248,6 +275,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
+ prev_state_ids: StateMap[str]
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertEqual(2, len(prev_state_ids))
@@ -255,7 +283,9 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
@defer.inlineCallbacks
- def test_branch_basic_conflict(self):
+ def test_branch_basic_conflict(
+ self,
+ ) -> Generator["defer.Deferred[object]", Any, None]:
graph = Graph(
nodes={
"START": DictObj(
@@ -280,7 +310,7 @@ class StateTestCase(unittest.TestCase):
self.dummy_store.register_events(graph.walk())
- context_store = {}
+ context_store: Dict[str, EventContext] = {}
for event in graph.walk():
context = yield defer.ensureDeferred(
@@ -294,6 +324,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
+ prev_state_ids: StateMap[str]
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
@@ -301,7 +332,9 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
@defer.inlineCallbacks
- def test_branch_have_banned_conflict(self):
+ def test_branch_have_banned_conflict(
+ self,
+ ) -> Generator["defer.Deferred[object]", Any, None]:
graph = Graph(
nodes={
"START": DictObj(
@@ -338,7 +371,7 @@ class StateTestCase(unittest.TestCase):
self.dummy_store.register_events(graph.walk())
- context_store = {}
+ context_store: Dict[str, EventContext] = {}
for event in graph.walk():
context = yield defer.ensureDeferred(
@@ -353,13 +386,16 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_e = context_store["E"]
+ prev_state_ids: StateMap[str]
prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@defer.inlineCallbacks
- def test_branch_have_perms_conflict(self):
+ def test_branch_have_perms_conflict(
+ self,
+ ) -> Generator["defer.Deferred[object]", Any, None]:
userid1 = "@user_id:example.com"
userid2 = "@user_id2:example.com"
@@ -413,7 +449,7 @@ class StateTestCase(unittest.TestCase):
self.dummy_store.register_events(graph.walk())
- context_store = {}
+ context_store: Dict[str, EventContext] = {}
for event in graph.walk():
context = yield defer.ensureDeferred(
@@ -428,14 +464,17 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"]
ctx_d = context_store["D"]
+ prev_state_ids: StateMap[str]
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
- def _add_depths(self, nodes, edges):
- def _get_depth(ev):
+ def _add_depths(
+ self, nodes: Dict[str, DictObj], edges: Dict[str, List[str]]
+ ) -> None:
+ def _get_depth(ev: str) -> int:
node = nodes[ev]
if "depth" not in node:
prevs = edges[ev]
@@ -447,7 +486,9 @@ class StateTestCase(unittest.TestCase):
_get_depth(n)
@defer.inlineCallbacks
- def test_annotate_with_old_message(self):
+ def test_annotate_with_old_message(
+ self,
+ ) -> Generator["defer.Deferred[object]", Any, None]:
event = create_event(type="test_message", name="event")
old_state = [
@@ -456,6 +497,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
+ context: EventContext
context = yield defer.ensureDeferred(
self.state.compute_event_context(
event,
@@ -466,9 +508,11 @@ class StateTestCase(unittest.TestCase):
)
)
+ prev_state_ids: StateMap[str]
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
+ current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
@@ -478,7 +522,9 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(context.state_group_before_event, context.state_group)
@defer.inlineCallbacks
- def test_annotate_with_old_state(self):
+ def test_annotate_with_old_state(
+ self,
+ ) -> Generator["defer.Deferred[object]", Any, None]:
event = create_event(type="state", state_key="", name="event")
old_state = [
@@ -487,6 +533,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
+ context: EventContext
context = yield defer.ensureDeferred(
self.state.compute_event_context(
event,
@@ -497,9 +544,11 @@ class StateTestCase(unittest.TestCase):
)
)
+ prev_state_ids: StateMap[str]
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
+ current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
@@ -511,7 +560,9 @@ class StateTestCase(unittest.TestCase):
self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
@defer.inlineCallbacks
- def test_trivial_annotate_message(self):
+ def test_trivial_annotate_message(
+ self,
+ ) -> Generator["defer.Deferred[object]", Any, None]:
prev_event_id = "prev_event_id"
event = create_event(
type="test_message", name="event2", prev_events=[(prev_event_id, {})]
@@ -534,8 +585,10 @@ class StateTestCase(unittest.TestCase):
)
self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
+ context: EventContext
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
+ current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(
@@ -545,7 +598,9 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(group_name, context.state_group)
@defer.inlineCallbacks
- def test_trivial_annotate_state(self):
+ def test_trivial_annotate_state(
+ self,
+ ) -> Generator["defer.Deferred[object]", Any, None]:
prev_event_id = "prev_event_id"
event = create_event(
type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
@@ -568,8 +623,10 @@ class StateTestCase(unittest.TestCase):
)
self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
+ context: EventContext
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
+ prev_state_ids: StateMap[str]
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
@@ -577,7 +634,9 @@ class StateTestCase(unittest.TestCase):
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
- def test_resolve_message_conflict(self):
+ def test_resolve_message_conflict(
+ self,
+ ) -> Generator["defer.Deferred[Any]", Any, None]:
prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
@@ -605,10 +664,12 @@ class StateTestCase(unittest.TestCase):
self.dummy_store.register_events(old_state_1)
self.dummy_store.register_events(old_state_2)
+ context: EventContext
context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
+ current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -616,7 +677,9 @@ class StateTestCase(unittest.TestCase):
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
- def test_resolve_state_conflict(self):
+ def test_resolve_state_conflict(
+ self,
+ ) -> Generator["defer.Deferred[Any]", Any, None]:
prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
@@ -645,12 +708,14 @@ class StateTestCase(unittest.TestCase):
store = _DummyStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
- self.dummy_store.get_events = store.get_events
+ self.dummy_store.get_events = store.get_events # type: ignore[assignment]
+ context: EventContext
context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
+ current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -658,7 +723,9 @@ class StateTestCase(unittest.TestCase):
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
- def test_standard_depth_conflict(self):
+ def test_standard_depth_conflict(
+ self,
+ ) -> Generator["defer.Deferred[Any]", Any, None]:
prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
@@ -700,12 +767,14 @@ class StateTestCase(unittest.TestCase):
store = _DummyStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
- self.dummy_store.get_events = store.get_events
+ self.dummy_store.get_events = store.get_events # type: ignore[assignment]
+ context: EventContext
context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
+ current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -740,8 +809,14 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def _get_context(
- self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
- ):
+ self,
+ event: EventBase,
+ prev_event_id_1: str,
+ old_state_1: Collection[EventBase],
+ prev_event_id_2: str,
+ old_state_2: Collection[EventBase],
+ ) -> Generator["defer.Deferred[object]", Any, EventContext]:
+ sg1: int
sg1 = yield defer.ensureDeferred(
self.dummy_store.store_state_group(
prev_event_id_1,
@@ -753,6 +828,7 @@ class StateTestCase(unittest.TestCase):
)
self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1)
+ sg2: int
sg2 = yield defer.ensureDeferred(
self.dummy_store.store_state_group(
prev_event_id_2,
@@ -767,7 +843,7 @@ class StateTestCase(unittest.TestCase):
result = yield defer.ensureDeferred(self.state.compute_event_context(event))
return result
- def test_make_state_cache_entry(self):
+ def test_make_state_cache_entry(self) -> None:
"Test that calculating a prev_group and delta is correct"
new_state = {
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index abd7459a8c..52424aa087 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -14,9 +14,12 @@
from unittest.mock import Mock
-from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.internet.interfaces import IReactorTime
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
from synapse.rest.client.register import register_servlets
+from synapse.server import HomeServer
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -25,7 +28,7 @@ from tests import unittest
class TermsTestCase(unittest.HomeserverTestCase):
servlets = [register_servlets]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config.update(
{
@@ -40,17 +43,21 @@ class TermsTestCase(unittest.HomeserverTestCase):
)
return config
- def prepare(self, reactor, clock, hs):
- self.clock = MemoryReactorClock()
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # type-ignore: mypy-zope doesn't seem to recognise that MemoryReactorClock
+ # implements IReactorTime, via inheritance from twisted.internet.testing.Clock
+ self.clock: IReactorTime = MemoryReactorClock() # type: ignore[assignment]
self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/register"
self.registration_handler = Mock()
self.auth_handler = Mock()
self.device_handler = Mock()
- def test_ui_auth(self):
+ def test_ui_auth(self) -> None:
# Do a UI auth request
- request_data = {"username": "kermit", "password": "monkey"}
+ request_data: JsonDict = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.code, 401, channel.result)
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index d04bcae0fa..5cd698147e 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -17,25 +17,25 @@ from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = MockClock()
- def test_advance_time(self):
+ def test_advance_time(self) -> None:
start_time = self.clock.time()
self.clock.advance_time(20)
self.assertEqual(20, self.clock.time() - start_time)
- def test_later(self):
+ def test_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
@@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase):
self.assertTrue(invoked[1])
- def test_cancel_later(self):
+ def test_cancel_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
t0 = self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
diff --git a/tests/test_types.py b/tests/test_types.py
index 1111169384..c491cc9a96 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase):
class UserIDTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
user = UserID.from_string("@1234abcd:test")
self.assertEqual("1234abcd", user.localpart)
self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user))
- def test_parse_rejects_empty_id(self):
+ def test_parse_rejects_empty_id(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("")
- def test_parse_rejects_missing_sigil(self):
+ def test_parse_rejects_missing_sigil(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("alice:example.com")
- def test_parse_rejects_missing_separator(self):
+ def test_parse_rejects_missing_separator(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("@alice.example.com")
- def test_validation_rejects_missing_domain(self):
+ def test_validation_rejects_missing_domain(self) -> None:
self.assertFalse(UserID.is_valid("@alice:"))
- def test_build(self):
+ def test_build(self) -> None:
user = UserID("5678efgh", "my.domain")
self.assertEqual(user.to_string(), "@5678efgh:my.domain")
- def test_compare(self):
+ def test_compare(self) -> None:
userA = UserID.from_string("@userA:my.domain")
userAagain = UserID.from_string("@userA:my.domain")
userB = UserID.from_string("@userB:my.domain")
@@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase):
class RoomAliasTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
room = RoomAlias.from_string("#channel:test")
self.assertEqual("channel", room.localpart)
self.assertEqual("test", room.domain)
self.assertEqual(True, self.hs.is_mine(room))
- def test_build(self):
+ def test_build(self) -> None:
room = RoomAlias("channel", "my.domain")
self.assertEqual(room.to_string(), "#channel:my.domain")
- def test_validate(self):
+ def test_validate(self) -> None:
id_string = "#test:domain,test"
self.assertFalse(RoomAlias.is_valid(id_string))
class MapUsernameTestCase(unittest.TestCase):
- def testPassThrough(self):
+ def test_pass_througuh(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
- def testUpperCase(self):
+ def test_upper_case(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234",
)
- def testSymbols(self):
+ def test_symbols(self) -> None:
self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
)
- def testLeadingUnderscore(self):
+ def test_leading_underscore(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
- def testNonAscii(self):
+ def test_non_ascii(self) -> None:
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e62ebcc6a5..e5dae670a7 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -20,12 +20,13 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Awaitable, Callable, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock
import attr
import zope.interface
+from twisted.internet.interfaces import IProtocol
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http import RESPONSES
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
from synapse.types import JsonDict
+if TYPE_CHECKING:
+ from sys import UnraisableHookArgs
+
TV = TypeVar("TV")
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook
- def unraisablehook(unraisable):
+ def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
unraisable_exceptions.append(unraisable.exc_value)
- def cleanup():
+ def cleanup() -> None:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
- raise unraisable_exceptions.pop()
+ exc = unraisable_exceptions.pop()
+ assert exc is not None
+ raise exc
sys.unraisablehook = unraisablehook
return cleanup
-def simple_async_mock(return_value=None, raises=None) -> Mock:
+def simple_async_mock(
+ return_value: Optional[TV] = None, raises: Optional[Exception] = None
+) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
+ async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises:
raise raises
return return_value
@@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc]
headers: Headers = attr.Factory(Headers)
@property
- def phrase(self):
+ def phrase(self) -> bytes:
return RESPONSES.get(self.code, b"Unknown Status")
@property
- def length(self):
+ def length(self) -> int:
return len(self.body)
- def deliverBody(self, protocol):
+ def deliverBody(self, protocol: IProtocol) -> None:
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8027c7a856..a6330ed840 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
@@ -32,7 +32,7 @@ async def inject_member_event(
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
@@ -57,7 +57,7 @@ async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a generic event into a room
@@ -82,7 +82,7 @@ async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> Tuple[EventBase, EventContext]:
if room_version is None:
room_version = await hs.get_datastores().main.get_room_version_id(
@@ -92,8 +92,13 @@ async def create_event(
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
- event, context = await hs.get_event_creation_handler().create_new_client_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
+ context = await unpersisted_context.persist(event)
+
return event, context
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index e878af5f12..189c697efb 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -13,13 +13,13 @@
# limitations under the License.
from html.parser import HTMLParser
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, NoReturn, Optional, Tuple
class TestHtmlParser(HTMLParser):
"""A generic HTML page parser which extracts useful things from the HTML"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
# a list of links found in the doc
@@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser):
assert input_name
self.hiddens[input_name] = attr_dict["value"]
- def error(_, message):
+ def error(self, message: str) -> NoReturn:
raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 304c7b98c5..b522163a34 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler):
tx_log = twisted.logger.Logger()
- def emit(self, record):
+ def emit(self, record: logging.LogRecord) -> None:
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
@@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler):
)
-def setup_logging():
+def setup_logging() -> None:
"""Configure the python logging appropriately for the tests.
(Logs will end up in _trial_temp.)
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@
import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
from urllib.parse import parse_qs
@@ -77,14 +77,14 @@ class FakeOidcServer:
self._id_token_overrides: Dict[str, Any] = {}
- def reset_mocks(self):
+ def reset_mocks(self) -> None:
self.request.reset_mock()
self.get_jwks_handler.reset_mock()
self.get_metadata_handler.reset_mock()
self.get_userinfo_handler.reset_mock()
self.post_token_handler.reset_mock()
- def patch_homeserver(self, hs: HomeServer):
+ def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
"""Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
return self._sign(logout_token)
- def id_token_override(self, overrides: dict):
+ def id_token_override(self, overrides: dict) -> ContextManager[dict]:
"""Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides)
@@ -247,7 +247,7 @@ class FakeOidcServer:
metadata: bool = False,
token: bool = False,
userinfo: bool = False,
- ):
+ ) -> ContextManager[Dict[str, Mock]]:
"""A context which makes a set of endpoints return a 500 error.
Args:
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0b9ad5454..2801a950a8 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -35,6 +35,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self._storage_controllers = self.hs.get_storage_controllers()
+ assert self._storage_controllers.persistence is not None
+ self._persistence = self._storage_controllers.persistence
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -175,12 +177,11 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ context = self.get_success(unpersisted_context.persist(event))
+ self.get_success(self._persistence.persist_event(event, context))
return event
def _inject_room_member(
@@ -202,13 +203,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ self.get_success(self._persistence.persist_event(event, context))
return event
def _inject_message(
@@ -226,13 +226,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ self.get_success(self._persistence.persist_event(event, context))
return event
def _inject_outlier(self) -> EventBase:
@@ -250,7 +249,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
- self._storage_controllers.persistence.persist_event(
+ self._persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers)
)
)
@@ -258,7 +257,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
- def test_out_of_band_invite_rejection(self):
+ def test_out_of_band_invite_rejection(self) -> None:
# this is where we have received an invite event over federation, and then
# rejected it.
invite_pdu = {
diff --git a/tests/unittest.py b/tests/unittest.py
index 50aa5abda9..6625fe1688 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
-from twisted.test.proto_helpers import MemoryReactor
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -82,7 +82,7 @@ from tests.server import (
)
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
-from tests.utils import default_config, setupdb
+from tests.utils import checked_cast, default_config, setupdb
setupdb()
setup_logging()
@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.utils import RestHelper
- self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
+ self.helper = RestHelper(
+ self.hs,
+ checked_cast(MemoryReactorClock, self.hs.get_reactor()),
+ self.site,
+ getattr(self, "user_id", None),
+ )
if hasattr(self, "user_id"):
if self.hijack_auth:
@@ -315,7 +320,7 @@ class HomeserverTestCase(TestCase):
# This has to be a function and not just a Mock, because
# `self.helper.auth_user_id` is temporarily reassigned in some tests
- async def get_requester(*args, **kwargs) -> Requester:
+ async def get_requester(*args: Any, **kwargs: Any) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
user_id=UserID.from_string(self.helper.auth_user_id),
@@ -361,7 +366,9 @@ class HomeserverTestCase(TestCase):
store.db_pool.updates.do_next_background_update(False), by=0.1
)
- def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock):
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
"""
Make and return a homeserver.
@@ -716,7 +723,7 @@ class HomeserverTestCase(TestCase):
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creator.create_event(
requester,
{
@@ -728,7 +735,7 @@ class HomeserverTestCase(TestCase):
prev_event_ids=prev_event_ids,
)
)
-
+ context = self.get_success(unpersisted_context.persist(event))
if soft_failed:
event.internal_metadata.soft_failed = True
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9529ee53c8..5f8f4e76b5 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -54,6 +54,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.pump()
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ assert new_timings is not None
self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, failure_ts)
self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
@@ -82,6 +83,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.pump()
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ assert new_timings is not None
self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, retry_ts)
self.assertGreaterEqual(
diff --git a/tests/utils.py b/tests/utils.py
index d76bf9716a..a0ac11bc5c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,7 +15,7 @@
import atexit
import os
-from typing import Any, Callable, Dict, List, Tuple, Union, overload
+from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload
import attr
from typing_extensions import Literal, ParamSpec
@@ -335,6 +335,33 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
},
)
- event, context = await event_creation_handler.create_new_client_event(builder)
+ event, unpersisted_context = await event_creation_handler.create_new_client_event(
+ builder
+ )
+ context = await unpersisted_context.persist(event)
await persistence_store.persist_event(event, context)
+
+
+T = TypeVar("T")
+
+
+def checked_cast(type: Type[T], x: object) -> T:
+ """A version of typing.cast that is checked at runtime.
+
+ We have our own function for this for two reasons:
+
+ 1. typing.cast itself is deliberately a no-op at runtime, see
+ https://docs.python.org/3/library/typing.html#typing.cast
+ 2. To help workaround a mypy-zope bug https://github.com/Shoobx/mypy-zope/issues/91
+ where mypy would erroneously consider `isinstance(x, type)` to be false in all
+ circumstances.
+
+ For this to make sense, `T` needs to be something that `isinstance` can check; see
+ https://docs.python.org/3/library/functions.html?highlight=isinstance#isinstance
+ https://docs.python.org/3/glossary.html#term-abstract-base-class
+ https://docs.python.org/3/library/typing.html#typing.runtime_checkable
+ for more details.
+ """
+ assert isinstance(x, type)
+ return x
|