diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 1cbb059357..0b22afdc75 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -24,6 +24,7 @@ from synapse.appservice.scheduler import (
)
from synapse.logging.context import make_deferred_yieldable
from synapse.server import HomeServer
+from synapse.types import DeviceListUpdates
from synapse.util import Clock
from tests import unittest
@@ -70,6 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
+ device_list_summary=DeviceListUpdates(),
)
self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@@ -96,6 +98,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
+ device_list_summary=DeviceListUpdates(),
)
self.assertEqual(0, txn.send.call_count) # txn not sent though
self.assertEqual(0, txn.complete.call_count) # or completed
@@ -124,6 +127,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
+ device_list_summary=DeviceListUpdates(),
)
self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made
self.assertEqual(1, self.recoverer.recover.call_count) # and invoked
@@ -225,7 +229,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service = Mock(id=4)
event = Mock()
self.scheduler.enqueue_for_appservice(service, events=[event])
- self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None)
+ self.txn_ctrl.send.assert_called_once_with(
+ service, [event], [], [], None, None, DeviceListUpdates()
+ )
def test_send_single_event_with_queue(self):
d = defer.Deferred()
@@ -240,12 +246,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# (call enqueue_for_appservice multiple times deliberately)
self.scheduler.enqueue_for_appservice(service, events=[event2])
self.scheduler.enqueue_for_appservice(service, events=[event3])
- self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None)
+ self.txn_ctrl.send.assert_called_with(
+ service, [event], [], [], None, None, DeviceListUpdates()
+ )
self.assertEqual(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
self.txn_ctrl.send.assert_called_with(
- service, [event2, event3], [], [], None, None
+ service, [event2, event3], [], [], None, None, DeviceListUpdates()
)
self.assertEqual(2, self.txn_ctrl.send.call_count)
@@ -272,15 +280,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# send events for different ASes and make sure they are sent
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
- self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None)
+ self.txn_ctrl.send.assert_called_with(
+ srv1, [srv_1_event], [], [], None, None, DeviceListUpdates()
+ )
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None)
+ self.txn_ctrl.send.assert_called_with(
+ srv2, [srv_2_event], [], [], None, None, DeviceListUpdates()
+ )
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None)
+ self.txn_ctrl.send.assert_called_with(
+ srv2, [srv_2_event2], [], [], None, None, DeviceListUpdates()
+ )
self.assertEqual(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
@@ -300,17 +314,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(
- service, [event_list[0]], [], [], None, None
+ service, [event_list[0]], [], [], None, None, DeviceListUpdates()
)
srv_1_defer.callback(service)
# Then send the next 100 events
self.txn_ctrl.send.assert_called_with(
- service, event_list[1:101], [], [], None, None
+ service, event_list[1:101], [], [], None, None, DeviceListUpdates()
)
srv_2_defer.callback(service)
# Then the final 99 events
self.txn_ctrl.send.assert_called_with(
- service, event_list[101:], [], [], None, None
+ service, event_list[101:], [], [], None, None, DeviceListUpdates()
)
self.assertEqual(3, self.txn_ctrl.send.call_count)
@@ -320,7 +334,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
- service, [], event_list, [], None, None
+ service, [], event_list, [], None, None, DeviceListUpdates()
)
def test_send_multiple_ephemeral_no_queue(self):
@@ -329,7 +343,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
- service, [], event_list, [], None, None
+ service, [], event_list, [], None, None, DeviceListUpdates()
)
def test_send_single_ephemeral_with_queue(self):
@@ -345,13 +359,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Send more events: expect send() to NOT be called multiple times.
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
- self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None)
+ self.txn_ctrl.send.assert_called_with(
+ service, [], event_list_1, [], None, None, DeviceListUpdates()
+ )
self.assertEqual(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(
- service, [], event_list_2 + event_list_3, [], None, None
+ service,
+ [],
+ event_list_2 + event_list_3,
+ [],
+ None,
+ None,
+ DeviceListUpdates(),
)
self.assertEqual(2, self.txn_ctrl.send.call_count)
@@ -365,8 +387,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = first_chunk + second_chunk
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
- service, [], first_chunk, [], None, None
+ service, [], first_chunk, [], None, None, DeviceListUpdates()
)
d.callback(service)
- self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None)
+ self.txn_ctrl.send.assert_called_with(
+ service, [], second_chunk, [], None, None, DeviceListUpdates()
+ )
self.assertEqual(2, self.txn_ctrl.send.call_count)
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index 694020fbef..06e0545a4f 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -28,8 +28,8 @@ from tests import unittest
SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
KEY_ALG = "ed25519"
-KEY_VER = 1
-KEY_NAME = "%s:%d" % (KEY_ALG, KEY_VER)
+KEY_VER = "1"
+KEY_NAME = "%s:%s" % (KEY_ALG, KEY_VER)
HOSTNAME = "domain"
@@ -39,7 +39,7 @@ class EventSigningTestCase(unittest.TestCase):
# NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been
# monkeypatched to include new `alg` and `version` attributes. This is captured
# by the `signedjson.types.SigningKey` protocol.
- self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey(
+ self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment]
SIGNING_KEY_SEED
)
self.signing_key.alg = KEY_ALG
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index e90592855a..a6e91956af 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -14,6 +14,7 @@
from typing import Optional
from unittest.mock import Mock
+from parameterized import parameterized_class
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
@@ -154,6 +155,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
)
+@parameterized_class(
+ [
+ {"enable_room_poke_code_path": False},
+ {"enable_room_poke_code_path": True},
+ ]
+)
class FederationSenderDevicesTestCases(HomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -168,17 +175,21 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def default_config(self):
c = super().default_config()
c["send_federation"] = True
+ c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path
return c
def prepare(self, reactor, clock, hs):
- # stub out get_users_who_share_room_with_user so that it claims that
- # `@user2:host2` is in the room
- def get_users_who_share_room_with_user(user_id):
+ # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+ # server thinks the user shares a room with `@user2:host2`
+ def get_rooms_for_user(user_id):
+ return defer.succeed({"!room:host1"})
+
+ hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+
+ def get_users_in_room(room_id):
return defer.succeed({"@user2:host2"})
- hs.get_datastores().main.get_users_who_share_room_with_user = (
- get_users_who_share_room_with_user
- )
+ hs.get_datastores().main.get_users_in_room = get_users_in_room
# whenever send_transaction is called, record the edu data
self.edus = []
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 648a01618e..d21c11b716 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -23,7 +23,7 @@ from synapse.server import HomeServer
from synapse.types import RoomAlias
from tests.test_utils import event_injection
-from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+from tests.unittest import FederatingHomeserverTestCase, TestCase
class KnockingStrippedStateEventHelperMixin(TestCase):
@@ -221,7 +221,6 @@ class FederationKnockingTestCase(
return super().prepare(reactor, clock, homeserver)
- @override_config({"experimental_features": {"msc2403_enabled": True}})
def test_room_state_returned_when_knocking(self):
"""
Tests that specific, stripped state events from a room are returned after
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index cead9f90df..8c72cf6b30 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -15,6 +15,8 @@
from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock
+from parameterized import parameterized
+
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
@@ -471,6 +473,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
to_device_messages,
_otks,
_fbks,
+ _device_list_summary,
) = self.send_mock.call_args[0]
# Assert that this was the same to-device message that local_user sent
@@ -583,7 +586,15 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
service_id_to_message_count: Dict[str, int] = {}
for call in self.send_mock.call_args_list:
- service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0]
+ (
+ service,
+ _events,
+ _ephemeral,
+ to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = call[0]
# Check that this was made to an interested service
self.assertIn(service, interested_appservices)
@@ -627,6 +638,114 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
return appservice
+class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase):
+ """
+ Tests that the ApplicationServicesHandler sends device list updates to application
+ services correctly.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ # Allow us to modify cached feature flags mid-test
+ self.as_handler = hs.get_application_service_handler()
+
+ # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
+ # will be sent over the wire
+ self.put_json = simple_async_mock()
+ hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
+
+ # Mock out application services, and allow defining our own in tests
+ self._services: List[ApplicationService] = []
+ self.hs.get_datastores().main.get_app_services = Mock(
+ return_value=self._services
+ )
+
+ # Test across a variety of configuration values
+ @parameterized.expand(
+ [
+ (True, True, True),
+ (True, False, False),
+ (False, True, False),
+ (False, False, False),
+ ]
+ )
+ def test_application_service_receives_device_list_updates(
+ self,
+ experimental_feature_enabled: bool,
+ as_supports_txn_extensions: bool,
+ as_should_receive_device_list_updates: bool,
+ ):
+ """
+ Tests that an application service receives notice of changed device
+ lists for a user, when a user changes their device lists.
+
+ Arguments above are populated by parameterized.
+
+ Args:
+ as_should_receive_device_list_updates: Whether we expect the AS to receive the
+ device list changes.
+ experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental
+ feature is enabled. This feature must be enabled for device lists to ASs to work.
+ as_supports_txn_extensions: Whether the application service has explicitly registered
+ to receive information defined by MSC3202 - which includes device list changes.
+ """
+ # Change whether the experimental feature is enabled or disabled before making
+ # device list changes
+ self.as_handler._msc3202_transaction_extensions_enabled = (
+ experimental_feature_enabled
+ )
+
+ # Create an appservice that is interested in "local_user"
+ appservice = ApplicationService(
+ token=random_string(10),
+ hostname="example.com",
+ id=random_string(10),
+ sender="@as:example.com",
+ rate_limited=False,
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": "@local_user:.+",
+ "exclusive": False,
+ }
+ ],
+ },
+ supports_ephemeral=True,
+ msc3202_transaction_extensions=as_supports_txn_extensions,
+ # Must be set for Synapse to try pushing data to the AS
+ hs_token="abcde",
+ url="some_url",
+ )
+
+ # Register the application service
+ self._services.append(appservice)
+
+ # Register a user on the homeserver
+ self.local_user = self.register_user("local_user", "password")
+ self.local_user_token = self.login("local_user", "password")
+
+ if as_should_receive_device_list_updates:
+ # Ensure that the resulting JSON uses the unstable prefix and contains the
+ # expected users
+ self.put_json.assert_called_once()
+ json_body = self.put_json.call_args[1]["json_body"]
+
+ # Our application service should have received a device list update with
+ # "local_user" in the "changed" list
+ device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {})
+ self.assertEqual([], device_list_dict["left"])
+ self.assertEqual([self.local_user], device_list_dict["changed"])
+
+ else:
+ # No device list changes should have been sent out
+ self.put_json.assert_not_called()
+
+
class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Argument indices for pulling out arguments from a `send_mock`.
ARG_OTK_COUNTS = 4
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 3a10791226..7586e472b5 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -44,21 +44,20 @@ class DeactivateAccountTestCase(HomeserverTestCase):
Deactivates the account `self.user` using `self.token` and asserts
that it returns a 200 success code.
"""
- req = self.get_success(
- self.make_request(
- "POST",
- "account/deactivate",
- {
- "auth": {
- "type": "m.login.password",
- "user": self.user,
- "password": "pass",
- },
- "erase": True,
+ req = self.make_request(
+ "POST",
+ "account/deactivate",
+ {
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user,
+ "password": "pass",
},
- access_token=self.token,
- )
+ "erase": True,
+ },
+ access_token=self.token,
)
+
self.assertEqual(req.code, HTTPStatus.OK, req)
def test_global_account_data_deleted_upon_deactivation(self) -> None:
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index ac21a28c43..8c74ed1fcf 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -463,8 +463,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 400)
- res = self.get_success(self.handler.query_local_devices({local_user: None}))
- self.assertDictEqual(res, {local_user: {}})
+ query_res = self.get_success(
+ self.handler.query_local_devices({local_user: None})
+ )
+ self.assertDictEqual(query_res, {local_user: {}})
def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded"""
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 89078fc637..060ba5f517 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -20,17 +20,17 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
-from synapse.events import EventBase
+from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.test_utils import event_injection
logger = logging.getLogger(__name__)
@@ -39,7 +39,7 @@ def generate_fake_event_id() -> str:
return "$fake_" + random_string(43)
-class FederationTestCase(unittest.HomeserverTestCase):
+class FederationTestCase(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -219,41 +219,77 @@ class FederationTestCase(unittest.HomeserverTestCase):
# create the room
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
- requester = create_requester(user_id)
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # we need a user on the remote server to be a member, so that we can send
+ # extremity-causing events.
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room_id, f"@user:{self.OTHER_SERVER_NAME}", "join"
+ )
+ )
- ev1 = self.helper.send(room_id, "first message", tok=tok)
+ send_result = self.helper.send(room_id, "first message", tok=tok)
+ ev1 = self.get_success(
+ self.store.get_event(send_result["event_id"], allow_none=False)
+ )
+ current_state = self.get_success(
+ self.store.get_events_as_list(
+ (self.get_success(self.store.get_current_state_ids(room_id))).values()
+ )
+ )
# Create "many" backward extremities. The magic number we're trying to
# create more than is 5 which corresponds to the number of backward
# extremities we slice off in `_maybe_backfill_inner`
+ federation_event_handler = self.hs.get_federation_event_handler()
for _ in range(0, 8):
- event_handler = self.hs.get_event_creation_handler()
- event, context = self.get_success(
- event_handler.create_event(
- requester,
+ event = make_event_from_dict(
+ self.add_hashes_and_signatures(
{
+ "origin_server_ts": 1,
"type": "m.room.message",
"content": {
"msgtype": "m.text",
"body": "message connected to fake event",
},
"room_id": room_id,
- "sender": user_id,
+ "sender": f"@user:{self.OTHER_SERVER_NAME}",
+ "prev_events": [
+ ev1.event_id,
+ # We're creating an backward extremity each time thanks
+ # to this fake event
+ generate_fake_event_id(),
+ ],
+ # lazy: *everything* is an auth event
+ "auth_events": [ev.event_id for ev in current_state],
+ "depth": ev1.depth + 1,
},
- prev_event_ids=[
- ev1["event_id"],
- # We're creating an backward extremity each time thanks
- # to this fake event
- generate_fake_event_id(),
- ],
- )
+ room_version,
+ ),
+ room_version,
)
+
+ # we poke this directly into _process_received_pdu, to avoid the
+ # federation handler wanting to backfill the fake event.
self.get_success(
- event_handler.handle_new_client_event(requester, event, context)
+ federation_event_handler._process_received_pdu(
+ self.OTHER_SERVER_NAME, event, state=current_state
+ )
)
+ # we should now have 8 backwards extremities.
+ backwards_extremities = self.get_success(
+ self.store.db_pool.simple_select_list(
+ "event_backward_extremities",
+ keyvalues={"room_id": room_id},
+ retcols=["event_id"],
+ )
+ )
+ self.assertEqual(len(backwards_extremities), 8)
+
current_depth = 1
limit = 100
with LoggingContext("receive_pdu"):
@@ -339,7 +375,8 @@ class FederationTestCase(unittest.HomeserverTestCase):
member_event.signatures = member_event_dict["signatures"]
# Add the new member_event to the StateMap
- prev_state_map[
+ updated_state_map = dict(prev_state_map)
+ updated_state_map[
(member_event.type, member_event.state_key)
] = member_event.event_id
auth_events.append(member_event)
@@ -363,7 +400,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
prev_event_ids=message_event_dict["prev_events"],
auth_event_ids=self._event_auth_handler.compute_auth_events(
builder,
- prev_state_map,
+ updated_state_map,
for_verification=False,
),
depth=message_event_dict["depth"],
@@ -496,8 +533,8 @@ class EventFromPduTestCase(TestCase):
def test_invalid_numbers(self) -> None:
"""Invalid values for an integer should be rejected, all floats should be rejected."""
for value in [
- -(2 ** 53),
- 2 ** 53,
+ -(2**53),
+ 2**53,
1.0,
float("inf"),
float("-inf"),
@@ -524,7 +561,7 @@ class EventFromPduTestCase(TestCase):
event_from_pdu_json(
{
"type": EventTypes.Message,
- "content": {"foo": [{"bar": 2 ** 56}]},
+ "content": {"foo": [{"bar": 2**56}]},
"room_id": "!room:test",
"sender": "@user:test",
"depth": 1,
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
new file mode 100644
index 0000000000..489ba57736
--- /dev/null
+++ b/tests/handlers/test_federation_event.py
@@ -0,0 +1,225 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+
+from synapse.events import make_event_from_dict
+from synapse.events.snapshot import EventContext
+from synapse.federation.transport.client import StateRequestResponse
+from synapse.logging.context import LoggingContext
+from synapse.rest import admin
+from synapse.rest.client import login, room
+
+from tests import unittest
+from tests.test_utils import event_injection, make_awaitable
+
+
+class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ # mock out the federation transport client
+ self.mock_federation_transport_client = mock.Mock(
+ spec=["get_room_state_ids", "get_room_state", "get_event"]
+ )
+ return super().setup_test_homeserver(
+ federation_transport_client=self.mock_federation_transport_client
+ )
+
+ def test_process_pulled_event_with_missing_state(self) -> None:
+ """Ensure that we correctly handle pulled events with lots of missing state
+
+ In this test, we pretend we are processing a "pulled" event (eg, via backfill
+ or get_missing_events). The pulled event has a prev_event we haven't previously
+ seen, so the server requests the state at that prev_event. There is a lot
+ of state we don't have, so we expect the server to make a /state request.
+
+ We check that the pulled event is correctly persisted, and that the state is
+ as we expect.
+ """
+ return self._test_process_pulled_event_with_missing_state(False)
+
+ def test_process_pulled_event_with_missing_state_where_prev_is_outlier(
+ self,
+ ) -> None:
+ """Ensure that we correctly handle pulled events with lots of missing state
+
+ A slight modification to test_process_pulled_event_with_missing_state. Again
+ we have a "pulled" event which refers to a prev_event with lots of state,
+ but in this case we already have the prev_event (as an outlier, obviously -
+ if it were a regular event, we wouldn't need to request the state).
+ """
+ return self._test_process_pulled_event_with_missing_state(True)
+
+ def _test_process_pulled_event_with_missing_state(
+ self, prev_exists_as_outlier: bool
+ ) -> None:
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+ state_storage = self.hs.get_storage().state
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # allow the remote user to send state events
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"events_default": 0, "state_default": 0},
+ tok=tok,
+ )
+
+ # add the remote user to the room
+ member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(main_store.get_current_state_ids(room_id))
+
+ auth_event_ids = [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ initial_state_map[("m.room.join_rules", "")],
+ member_event.event_id,
+ ]
+
+ # mock up a load of state events which we are missing
+ state_events = [
+ make_event_from_dict(
+ self.add_hashes_and_signatures(
+ {
+ "type": "test_state_type",
+ "state_key": f"state_{i}",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [member_event.event_id],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 10,
+ "content": {"body": f"state_{i}"},
+ }
+ ),
+ room_version,
+ )
+ for i in range(1, 10)
+ ]
+
+ # this is the state that we are going to claim is active at the prev_event.
+ state_at_prev_event = state_events + self.get_success(
+ main_store.get_events_as_list(initial_state_map.values())
+ )
+
+ # mock up a prev event.
+ # Depending on the test, we either persist this upfront (as an outlier),
+ # or let the server request it.
+ prev_event = make_event_from_dict(
+ self.add_hashes_and_signatures(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 11,
+ "content": {"body": "missing_prev"},
+ }
+ ),
+ room_version,
+ )
+ if prev_exists_as_outlier:
+ prev_event.internal_metadata.outlier = True
+ persistence = self.hs.get_storage().persistence
+ self.get_success(
+ persistence.persist_event(prev_event, EventContext.for_outlier())
+ )
+ else:
+
+ async def get_event(destination: str, event_id: str, timeout=None):
+ self.assertEqual(destination, self.OTHER_SERVER_NAME)
+ self.assertEqual(event_id, prev_event.event_id)
+ return {"pdus": [prev_event.get_pdu_json()]}
+
+ self.mock_federation_transport_client.get_event.side_effect = get_event
+
+ # mock up a regular event to pass into _process_pulled_event
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [prev_event.event_id],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled"},
+ }
+ ),
+ room_version,
+ )
+
+ # we expect an outbound request to /state_ids, so stub that out
+ self.mock_federation_transport_client.get_room_state_ids.return_value = (
+ make_awaitable(
+ {
+ "pdu_ids": [e.event_id for e in state_at_prev_event],
+ "auth_chain_ids": [],
+ }
+ )
+ )
+
+ # we also expect an outbound request to /state
+ self.mock_federation_transport_client.get_room_state.return_value = (
+ make_awaitable(
+ StateRequestResponse(auth_events=[], state=state_at_prev_event)
+ )
+ )
+
+ # we have to bump the clock a bit, to keep the retry logic in
+ # FederationClient.get_pdu happy
+ self.reactor.advance(60000)
+
+ # Finally, the call under test: send the pulled event into _process_pulled_event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, pulled_event, backfilled=False
+ )
+ )
+
+ # check that the event is correctly persisted
+ persisted = self.get_success(main_store.get_event(pulled_event.event_id))
+ self.assertIsNotNone(persisted, "pulled event was not persisted at all")
+ self.assertFalse(
+ persisted.internal_metadata.is_outlier(), "pulled event was an outlier"
+ )
+
+ # check that the state at that event is as expected
+ state = self.get_success(
+ state_storage.get_state_ids_for_event(pulled_event.event_id)
+ )
+ expected_state = {
+ (e.type, e.state_key): e.event_id for e in state_at_prev_event
+ }
+ self.assertEqual(state, expected_state)
+
+ if prev_exists_as_outlier:
+ self.mock_federation_transport_client.get_event.assert_not_called()
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 014815db6e..9684120c70 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -354,10 +354,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
req = Mock(spec=["cookies"])
req.cookies = []
- url = self.get_success(
- self.provider.handle_redirect_request(req, b"http://client/redirect")
+ url = urlparse(
+ self.get_success(
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
+ )
)
- url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
self.assertEqual(url.scheme, auth_endpoint.scheme)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 1ec105c373..f88c725a42 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -59,7 +59,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- self.get_success(self.register_user(self.frank.localpart, "frankpassword"))
+ self.register_user(self.frank.localpart, "frankpassword")
self.handler = hs.get_profile_handler()
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 3aedc0767b..865b8b7e47 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -158,9 +158,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
# Blow away caches (supported room versions can only change due to a restart).
- self.get_success(
- self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
- )
+ self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
self.store._get_event_cache.clear()
# The rooms should be excluded from the sync response.
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 92012cd6f7..c6e501c7be 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -351,6 +351,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.handler.handle_local_profile_change(regular_user_id, profile_info)
)
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
+ assert profile is not None
self.assertTrue(profile["display_name"] == display_name)
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
@@ -369,6 +370,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ assert profile is not None
self.assertTrue(profile["display_name"] == display_name)
# deactivate user
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 10dd94b549..9fd5d59c55 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -87,11 +87,23 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(displayname, "Bobberino")
def test_can_register_admin_user(self):
- user_id = self.get_success(
- self.register_user(
- "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
- )
+ user_id = self.register_user(
+ "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
)
+
+ found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ self.assertEqual(found_user.user_id.to_string(), user_id)
+ self.assertIdentical(found_user.is_admin, True)
+
+ def test_can_set_admin(self):
+ user_id = self.register_user(
+ "alice_wants_admin",
+ "1234",
+ displayname="Alice Powerhungry",
+ admin=False,
+ )
+
+ self.get_success(self.module_api.set_user_admin(user_id, True))
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
@@ -278,7 +290,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Create a user and room to play with
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
- room_id = self.helper.create_room_as(user_id, tok=tok)
+ room_id = self.helper.create_room_as(user_id, tok=tok, is_public=False)
# The room should not currently be in the public rooms directory
is_in_public_rooms = self.get_success(
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9c5df266bd..a0589b6d6a 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -206,7 +206,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
path: bytes = request.path # type: ignore
self.assertRegex(
path,
- br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
+ rb"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),),
)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 17dc42fd37..297a9e77f8 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -268,7 +268,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_source = RoomEventSource(self.hs)
event_source.store = self.slaved_store
- current_token = self.get_success(event_source.get_current_key())
+ current_token = event_source.get_current_key()
# gradually stream out the replication
while repl_transport.buffer:
@@ -277,7 +277,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.pump(0)
prev_token = current_token
- current_token = self.get_success(event_source.get_current_key())
+ current_token = event_source.get_current_key()
# attempt to replicate the behaviour of the sync handler.
#
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 0d47dd0aff..e909e444ac 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -702,6 +702,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
"""
media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
self.assertFalse(media_info["quarantined_by"])
# quarantining
@@ -715,6 +716,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
self.assertTrue(media_info["quarantined_by"])
# remove from quarantine
@@ -728,6 +730,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
self.assertFalse(media_info["quarantined_by"])
def test_quarantine_protected_media(self) -> None:
@@ -740,6 +743,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])
# quarantining
@@ -754,6 +758,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
self.assertFalse(media_info["quarantined_by"])
@@ -830,6 +835,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
"""
media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])
# protect
@@ -843,6 +849,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])
# unprotect
@@ -856,6 +863,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 2c855bff99..a53463c9ba 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -214,9 +214,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(messages[0]["sender"], "@notices:test")
# invalidate cache of server notices room_ids
- self.get_success(
- self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
- )
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
# send second message
channel = self.make_request(
@@ -291,9 +289,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# invalidate cache of server notices room_ids
# if server tries to send to a cached room_id the user gets the message
# in old room
- self.get_success(
- self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
- )
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
# send second message
channel = self.make_request(
@@ -380,9 +376,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# invalidate cache of server notices room_ids
# if server tries to send to a cached room_id it gives an error
- self.get_success(
- self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
- )
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
# send second message
channel = self.make_request(
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index bef911d5df..0cdf1dec40 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1590,10 +1590,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- pushers = self.get_success(
- self.store.get_pushers_by({"user_name": "@bob:test"})
+ pushers = list(
+ self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual("@bob:test", pushers[0].user_name)
@@ -1632,10 +1631,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- pushers = self.get_success(
- self.store.get_pushers_by({"user_name": "@bob:test"})
+ pushers = list(
+ self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def test_set_password(self) -> None:
@@ -2144,6 +2142,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# is in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ assert profile is not None
self.assertTrue(profile["display_name"] == "User")
# Deactivate user
@@ -2711,6 +2710,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
user_tuple = self.get_success(
self.store.get_user_by_access_token(other_user_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -3676,6 +3676,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# The user starts off as not shadow-banned.
other_user_token = self.login("user", "pass")
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ assert result is not None
self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
@@ -3684,6 +3685,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# Ensure the user is shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ assert result is not None
self.assertTrue(result.shadow_banned)
# Un-shadow-ban the user.
@@ -3695,6 +3697,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# Ensure the user is no longer shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ assert result is not None
self.assertFalse(result.shadow_banned)
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 27946febff..e00b5c171c 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -89,6 +89,17 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
+ def attempt_wrong_password_login(self, username: str, password: str) -> None:
+ """Attempts to login as the user with the given password, asserting
+ that the attempt *fails*.
+ """
+ body = {"type": "m.login.password", "user": username, "password": password}
+
+ channel = self.make_request(
+ "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+
def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
old_password = "monkey"
diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py
new file mode 100644
index 0000000000..d5b0640e7a
--- /dev/null
+++ b/tests/rest/client/test_account_data.py
@@ -0,0 +1,75 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from synapse.rest import admin
+from synapse.rest.client import account_data, login, room
+
+from tests import unittest
+from tests.test_utils import make_awaitable
+
+
+class AccountDataTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ account_data.register_servlets,
+ ]
+
+ def test_on_account_data_updated_callback(self) -> None:
+ """Tests that the on_account_data_updated module callback is called correctly when
+ a user's account data changes.
+ """
+ mocked_callback = Mock(return_value=make_awaitable(None))
+ self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
+ mocked_callback
+ )
+
+ user_id = self.register_user("user", "password")
+ tok = self.login("user", "password")
+ account_data_type = "org.matrix.foo"
+ account_data_content = {"bar": "baz"}
+
+ # Change the user's global account data.
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user_id}/account_data/{account_data_type}",
+ account_data_content,
+ access_token=tok,
+ )
+
+ # Test that the callback is called with the user ID, the new account data, and
+ # None as the room ID.
+ self.assertEqual(channel.code, 200, channel.result)
+ mocked_callback.assert_called_once_with(
+ user_id, None, account_data_type, account_data_content
+ )
+
+ # Change the user's room-specific account data.
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user_id}/rooms/{room_id}/account_data/{account_data_type}",
+ account_data_content,
+ access_token=tok,
+ )
+
+ # Test that the callback is called with the user ID, the room ID and the new
+ # account data.
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(mocked_callback.call_count, 2)
+ mocked_callback.assert_called_with(
+ user_id, room_id, account_data_type, account_data_content
+ )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index fe97a0b3dd..419eef166a 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import urllib.parse
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
@@ -145,16 +144,6 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
return channel.json_body["unsigned"].get("m.relations", {})
- def _get_aggregations(self) -> List[JsonDict]:
- """Request /aggregations on the parent ID and includes the returned chunk."""
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- return channel.json_body["chunk"]
-
def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
"""
Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
@@ -264,43 +253,6 @@ class RelationsTestCase(BaseRelationsTestCase):
expected_response_code=400,
)
- def test_aggregation(self) -> None:
- """Test that annotations get correctly aggregated."""
-
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(
- channel.json_body,
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- )
-
- def test_aggregation_must_be_annotation(self) -> None:
- """Test that aggregations must be annotations."""
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations"
- f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(400, channel.code, channel.json_body)
-
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
@@ -394,15 +346,6 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
- # And when fetching aggregations.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
-
# And for bundled aggregations.
channel = self.make_request(
"GET",
@@ -717,15 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])
- # But unknown relations can be directly queried.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
-
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
@@ -941,131 +875,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
)
- def test_aggregation_pagination_groups(self) -> None:
- """Test that we can paginate annotation groups correctly."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
- for key in itertools.chain.from_iterable(
- itertools.repeat(key, num) for key, num in sent_groups.items()
- ):
- self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key=key,
- access_token=access_tokens[idx],
- )
-
- idx += 1
- idx %= len(access_tokens)
-
- prev_token: Optional[str] = None
- found_groups: Dict[str, int] = {}
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- for groups in channel.json_body["chunk"]:
- # We only expect reactions
- self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
- # We should only see each key once
- self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
- found_groups[groups["key"]] = groups["count"]
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- self.assertEqual(sent_groups, found_groups)
-
- def test_aggregation_pagination_within_group(self) -> None:
- """Test that we can paginate within an annotation group."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- expected_event_ids = []
- for _ in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="👍",
- access_token=access_tokens[idx],
- )
- expected_event_ids.append(channel.json_body["event_id"])
-
- idx += 1
-
- # Also send a different type of reaction so that we test we don't see it
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-
- prev_token = ""
- found_event_ids: List[str] = []
- encoded_key = urllib.parse.quote_plus("👍".encode())
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
class BundledAggregationsTestCase(BaseRelationsTestCase):
"""
@@ -1453,10 +1262,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
{"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
)
- # Both relations appear in the aggregation.
- chunk = self._get_aggregations()
- self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
-
# Redact one of the reactions.
self._redact(to_redact_event_id)
@@ -1469,10 +1274,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
- # The unredacted aggregation should still exist.
- chunk = self._get_aggregations()
- self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
-
def test_redact_relation_thread(self) -> None:
"""
Test that thread replies are properly handled after the thread reply redacted.
@@ -1578,10 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self.assertEqual(len(event_ids), 1)
self.assertIn(RelationTypes.ANNOTATION, relations)
- # The aggregation should exist.
- chunk = self._get_aggregations()
- self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
-
# Redact the original event.
self._redact(self.parent_id)
@@ -1594,10 +1391,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
{"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
)
- # There's nothing to aggregate.
- chunk = self._get_aggregations()
- self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
-
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_redact_parent_thread(self) -> None:
"""
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 3a9617d6da..6ff79b9e2e 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -982,7 +982,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
super().prepare(reactor, clock, hs)
# profile changes expect that the user is actually registered
user = UserID.from_string(self.user_id)
- self.get_success(self.register_user(user.localpart, "supersecretpassword"))
+ self.register_user(user.localpart, "supersecretpassword")
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 4351013952..773c16a54c 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -341,7 +341,6 @@ class SyncKnockTestCase(
hs, self.room_id, self.user_id
)
- @override_config({"experimental_features": {"msc2403_enabled": True}})
def test_knock_room_state(self) -> None:
"""Tests that /sync returns state from a room after knocking on it."""
# Knock on a room
@@ -497,6 +496,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = {"msc2654_enabled": True}
+ return config
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s"
self.next_batch = "s0"
@@ -772,3 +776,65 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
self.assertIn(
self.user_id, device_list_changes, incremental_sync_channel.json_body
)
+
+
+class ExcludeRoomTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self.user_id = self.register_user("user", "password")
+ self.tok = self.login("user", "password")
+
+ self.excluded_room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+ self.included_room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # We need to manually append the room ID, because we can't know the ID before
+ # creating the room, and we can't set the config after starting the homeserver.
+ self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id)
+
+ def test_join_leave(self) -> None:
+ """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of
+ sync responses.
+ """
+ channel = self.make_request("GET", "/sync", access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
+
+ self.helper.leave(self.excluded_room_id, self.user_id, tok=self.tok)
+ self.helper.leave(self.included_room_id, self.user_id, tok=self.tok)
+
+ channel = self.make_request(
+ "GET",
+ "/sync?since=" + channel.json_body["next_batch"],
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["leave"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["leave"])
+
+ def test_invite(self) -> None:
+ """Tests that rooms are correctly excluded from the 'invite' section of sync
+ responses.
+ """
+ invitee = self.register_user("invitee", "password")
+ invitee_tok = self.login("invitee", "password")
+
+ self.helper.invite(self.excluded_room_id, self.user_id, invitee, tok=self.tok)
+ self.helper.invite(self.included_room_id, self.user_id, invitee, tok=self.tok)
+
+ channel = self.make_request("GET", "/sync", access_token=invitee_tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e7de67e3a3..5eb0f243f7 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -896,3 +896,44 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right room ID
self.assertEqual(args[1], self.room_id)
+
+ def test_on_threepid_bind(self) -> None:
+ """Tests that the on_threepid_bind module callback is called correctly after
+ associating a 3PID to an account.
+ """
+ # Register a mocked callback.
+ threepid_bind_mock = Mock(return_value=make_awaitable(None))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the shutdown was blocked
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that the mock was called once.
+ threepid_bind_mock.assert_called_once()
+ args = threepid_bind_mock.call_args[0]
+
+ # Check that the mock was called with the right parameters
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 28663826fc..a0788b1bb0 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -88,7 +88,7 @@ class RestHelper:
def create_room_as(
self,
room_creator: Optional[str] = None,
- is_public: Optional[bool] = None,
+ is_public: Optional[bool] = True,
room_version: Optional[str] = None,
tok: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
@@ -101,9 +101,12 @@ class RestHelper:
Args:
room_creator: The user ID to create the room with.
is_public: If True, the `visibility` parameter will be set to
- "public". If False, it will be set to "private". If left
- unspecified, the server will set it to an appropriate default
- (which should be "private" as per the CS spec).
+ "public". If False, it will be set to "private".
+ If None, doesn't specify the `visibility` parameter in which
+ case the server is supposed to make the room private according to
+ the CS API.
+ Defaults to public, since that is commonly needed in tests
+ for convenience where room privacy is not a problem.
room_version: The room version to create the room as. Defaults to Synapse's
default room version.
tok: The access token to use in the request.
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 978c252f84..ac0ac06b7e 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -76,7 +76,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
"verify_keys": {
key_id: {
"key": signedjson.key.encode_verify_key_base64(
- signing_key.verify_key
+ signedjson.key.get_verify_key(signing_key)
)
}
},
@@ -175,7 +175,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
% (
self.hs_signing_key.version,
): signedjson.key.encode_verify_key_base64(
- self.hs_signing_key.verify_key
+ signedjson.key.get_verify_key(self.hs_signing_key)
)
},
}
@@ -229,7 +229,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
assert isinstance(keyres, FetchKeyResult)
self.assertEqual(
signedjson.key.encode_verify_key_base64(keyres.verify_key),
- signedjson.key.encode_verify_key_base64(testkey.verify_key),
+ signedjson.key.encode_verify_key_base64(
+ signedjson.key.get_verify_key(testkey)
+ ),
)
def test_get_notary_key(self) -> None:
@@ -251,7 +253,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
assert isinstance(keyres, FetchKeyResult)
self.assertEqual(
signedjson.key.encode_verify_key_base64(keyres.verify_key),
- signedjson.key.encode_verify_key_base64(testkey.verify_key),
+ signedjson.key.encode_verify_key_base64(
+ signedjson.key.get_verify_key(testkey)
+ ),
)
def test_get_notary_keyserver_key(self) -> None:
@@ -268,5 +272,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
assert isinstance(keyres, FetchKeyResult)
self.assertEqual(
signedjson.key.encode_verify_key_base64(keyres.verify_key),
- signedjson.key.encode_verify_key_base64(self.hs_signing_key.verify_key),
+ signedjson.key.encode_verify_key_base64(
+ signedjson.key.get_verify_key(self.hs_signing_key)
+ ),
)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 5148c39874..3b24d0ace6 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -17,7 +17,7 @@ import json
import os
import re
from typing import Any, Dict, Optional, Sequence, Tuple, Type
-from urllib.parse import urlencode
+from urllib.parse import quote, urlencode
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
@@ -69,7 +69,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"2001:800::/21",
)
config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
- config["url_preview_url_blacklist"] = []
config["url_preview_accept_language"] = [
"en-UK",
"en-US;q=0.9",
@@ -1123,3 +1122,43 @@ class URLPreviewTests(unittest.HomeserverTestCase):
os.path.exists(path),
f"{os.path.relpath(path, self.media_store_path)} was not deleted",
)
+
+ @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]})
+ def test_blacklist_port(self) -> None:
+ """Tests that blacklisting URLs with a port makes previewing such URLs
+ fail with a 403 error and doesn't impact other previews.
+ """
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ bad_url = quote("http://matrix.org:8888/foo")
+ good_url = quote("http://matrix.org/foo")
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=" + bad_url,
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 403, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=" + good_url,
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
diff --git a/tests/server.py b/tests/server.py
index 6ce2a17bf4..aaa5ca3e74 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,7 +22,6 @@ import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
- AnyStr,
Callable,
Dict,
Iterable,
@@ -86,6 +85,9 @@ from tests.utils import (
logger = logging.getLogger(__name__)
+# the type of thing that can be passed into `make_request` in the headers list
+CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
+
class TimedOutException(Exception):
"""
@@ -260,7 +262,7 @@ def make_request(
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 3ac4646969..74c6224eb6 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""
# First to acquire this lock, so it should complete
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
- self.assertIsNotNone(lock)
+ assert lock is not None
# Enter the context manager
self.get_success(lock.__aenter__())
@@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase):
# We can now acquire the lock again.
lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
- self.assertIsNotNone(lock3)
+ assert lock3 is not None
self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None))
@@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""Test that we don't time out locks while they're still active"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
- self.assertIsNotNone(lock)
+ assert lock is not None
self.get_success(lock.__aenter__())
@@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""Test that we time out locks if they're not updated for ages"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
- self.assertIsNotNone(lock)
+ assert lock is not None
self.get_success(lock.__aenter__())
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ee599f4336..1bf93e79a7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
+from synapse.types import DeviceListUpdates
from synapse.util import Clock
from tests import unittest
@@ -168,15 +169,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
(as_id, txn_id, json.dumps([e.event_id for e in events])),
)
- def _set_last_txn(self, as_id, txn_id):
- return self.db_pool.runOperation(
- self.engine.convert_param_style(
- "INSERT INTO application_services_state(as_id, last_txn, state) "
- "VALUES(?,?,?)"
- ),
- (as_id, txn_id, ApplicationServiceState.UP.value),
- )
-
def test_get_appservice_state_none(
self,
) -> None:
@@ -267,65 +259,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success(
defer.ensureDeferred(
- self.store.create_appservice_txn(service, events, [], [], {}, {})
+ self.store.create_appservice_txn(
+ service, events, [], [], {}, {}, DeviceListUpdates()
+ )
)
)
self.assertEqual(txn.id, 1)
self.assertEqual(txn.events, events)
self.assertEqual(txn.service, service)
- def test_create_appservice_txn_older_last_txn(
- self,
- ) -> None:
- service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
- self.get_success(self._insert_txn(service.id, 9644, events))
- self.get_success(self._insert_txn(service.id, 9645, events))
- txn = self.get_success(
- self.store.create_appservice_txn(service, events, [], [], {}, {})
- )
- self.assertEqual(txn.id, 9646)
- self.assertEqual(txn.events, events)
- self.assertEqual(txn.service, service)
-
- def test_create_appservice_txn_up_to_date_last_txn(
- self,
- ) -> None:
- service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- self.get_success(self._set_last_txn(service.id, 9643))
- txn = self.get_success(
- self.store.create_appservice_txn(service, events, [], [], {}, {})
- )
- self.assertEqual(txn.id, 9644)
- self.assertEqual(txn.events, events)
- self.assertEqual(txn.service, service)
-
- def test_create_appservice_txn_up_fuzzing(
- self,
- ) -> None:
- service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- self.get_success(self._set_last_txn(service.id, 9643))
-
- # dump in rows with higher IDs to make sure the queries aren't wrong.
- self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643))
- self.get_success(self._set_last_txn(self.as_list[2]["id"], 9))
- self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events))
- self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
- self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
-
- txn = self.get_success(
- self.store.create_appservice_txn(service, events, [], [], {}, {})
- )
- self.assertEqual(txn.id, 9644)
- self.assertEqual(txn.events, events)
- self.assertEqual(txn.service, service)
-
def test_complete_appservice_txn_first_txn(
self,
) -> None:
@@ -359,13 +301,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(0, len(res))
- def test_complete_appservice_txn_existing_in_state_table(
+ def test_complete_appservice_txn_updates_last_txn_state(
self,
) -> None:
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn_id = 5
- self.get_success(self._set_last_txn(service.id, 4))
+ self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
self.get_success(self._insert_txn(service.id, txn_id, events))
self.get_success(
self.store.complete_appservice_txn(txn_id=txn_id, service=service)
@@ -416,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 12, other_events))
txn = self.get_success(self.store.get_oldest_unsent_txn(service))
+ assert txn is not None
self.assertEqual(service, txn.service)
self.assertEqual(10, txn.id)
self.assertEqual(events, txn.events)
@@ -476,12 +419,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
value = self.get_success(
self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
)
- self.assertEqual(value, 0)
+ self.assertEqual(value, 1)
value = self.get_success(
self.store.get_type_stream_id_for_appservice(self.service, "presence")
)
- self.assertEqual(value, 0)
+ self.assertEqual(value, 1)
def test_get_type_stream_id_for_appservice_invalid_type(self) -> None:
self.get_failure(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index ce89c96912..b998ad42d9 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -68,6 +68,22 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.wait_for_background_updates()
+ def add_extremity(self, room_id: str, event_id: str) -> None:
+ """
+ Add the given event as an extremity to the room.
+ """
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_insert(
+ table="event_forward_extremities",
+ values={"room_id": room_id, "event_id": event_id},
+ desc="test_add_extremity",
+ )
+ )
+
+ self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate(
+ (room_id,)
+ )
+
def test_soft_failed_extremities_handled_correctly(self):
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -250,7 +266,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
self.requester = create_requester(self.user)
- info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
+ info, _ = self.get_success(
+ self.room_creator.create_room(self.requester, {"visibility": "public"})
+ )
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 21ffc5a909..d1227dd4ac 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add two device updates with sequential `stream_id`s
self.get_success(
- self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"], ["!some:room"]
+ )
)
# Get all device updates ever meant for this remote
@@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
"device_id5",
]
self.get_success(
- self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"], ["!some:room"]
+ )
)
# Get device updates meant for this remote
@@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add some more device updates to ensure it still resumes properly
device_ids = ["device_id6", "device_id7"]
self.get_success(
- self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"], ["!some:room"]
+ )
)
# Get the next batch of device updates
@@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
self.get_success(
self.store.add_device_change_to_streams(
- "@user_id:test", device_ids, ["somehost"]
+ "@user_id:test", device_ids, ["somehost"], ["!some:room"]
)
)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 395396340b..2d8d1f860f 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -157,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- ctx1 = self.get_success(id_gen.get_next())
- ctx2 = self.get_success(id_gen.get_next())
- ctx3 = self.get_success(id_gen.get_next())
- ctx4 = self.get_success(id_gen.get_next())
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+ ctx4 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__())
@@ -362,8 +362,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Persist two rows at once
- ctx1 = self.get_success(id_gen.get_next())
- ctx2 = self.get_success(id_gen.get_next())
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__())
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 03e9cc7d4a..d8d17ef379 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -119,11 +119,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
return event
def test_redact(self):
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+ msg_event = self.inject_message(self.room1, self.u_alice, "t")
# Check event has not been redacted:
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -141,9 +139,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
# Redact event
reason = "Because I said so"
- self.get_success(
- self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
- )
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -170,14 +166,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
def test_redact_join(self):
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- msg_event = self.get_success(
- self.inject_room_member(
- self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
- )
+ msg_event = self.inject_room_member(
+ self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
)
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -195,9 +187,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
# Redact event
reason = "Because I said so"
- self.get_success(
- self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
- )
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
# Check redaction
@@ -311,11 +301,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def test_redact_censor(self):
"""Test that a redacted event gets censored in the DB after a month"""
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+ msg_event = self.inject_message(self.room1, self.u_alice, "t")
# Check event has not been redacted:
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -333,9 +321,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
# Redact event
reason = "Because I said so"
- self.get_success(
- self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
- )
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -381,25 +367,19 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def test_redact_redaction(self):
"""Tests that we can redact a redaction and can fetch it again."""
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+ msg_event = self.inject_message(self.room1, self.u_alice, "t")
- first_redact_event = self.get_success(
- self.inject_redaction(
- self.room1, msg_event.event_id, self.u_alice, "Redacting message"
- )
+ first_redact_event = self.inject_redaction(
+ self.room1, msg_event.event_id, self.u_alice, "Redacting message"
)
- self.get_success(
- self.inject_redaction(
- self.room1,
- first_redact_event.event_id,
- self.u_alice,
- "Redacting redaction",
- )
+ self.inject_redaction(
+ self.room1,
+ first_redact_event.event_id,
+ self.u_alice,
+ "Redacting redaction",
)
# Now lets jump to the future where we have censored the redaction event
@@ -414,9 +394,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def test_store_redacted_redaction(self):
"""Tests that we can store a redacted redaction."""
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index eaa0d7d749..52e41cdab4 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -110,9 +110,7 @@ class PaginationTestCase(HomeserverTestCase):
def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
"""Make a request to /messages with a filter, returns the chunk of events."""
- from_token = self.get_success(
- self.hs.get_event_sources().get_current_token_for_pagination()
- )
+ from_token = self.hs.get_event_sources().get_current_token_for_pagination()
events, next_key = self.get_success(
self.hs.get_datastores().main.paginate_room_events(
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 532e3fe9cd..d0230f9ebb 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -17,6 +17,7 @@ from unittest.mock import patch
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.types import JsonDict, create_requester
from synapse.visibility import filter_events_for_client, filter_events_for_server
@@ -47,17 +48,15 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
#
# before we do that, we persist some other events to act as state.
- self.get_success(self._inject_visibility("@admin:hs", "joined"))
+ self._inject_visibility("@admin:hs", "joined")
for i in range(0, 10):
- self.get_success(self._inject_room_member("@resident%i:hs" % i))
+ self._inject_room_member("@resident%i:hs" % i)
events_to_filter = []
for i in range(0, 10):
user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
- evt = self.get_success(
- self._inject_room_member(user, extra_content={"a": "b"})
- )
+ evt = self._inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
filtered = self.get_success(
@@ -73,24 +72,57 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["a"], "b")
+ def test_filter_outlier(self) -> None:
+ # outlier events must be returned, for the good of the collective federation
+ self._inject_room_member("@resident:remote_hs")
+ self._inject_visibility("@resident:remote_hs", "joined")
+
+ outlier = self._inject_outlier()
+ self.assertEqual(
+ self.get_success(
+ filter_events_for_server(self.storage, "remote_hs", [outlier])
+ ),
+ [outlier],
+ )
+
+ # it should also work when there are other events in the list
+ evt = self._inject_message("@unerased:local_hs")
+
+ filtered = self.get_success(
+ filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+ )
+ self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
+ self.assertEqual(filtered[0], outlier)
+ self.assertEqual(filtered[1].event_id, evt.event_id)
+ self.assertEqual(filtered[1].content, evt.content)
+
+ # ... but other servers should only be able to see the outlier (the other should
+ # be redacted)
+ filtered = self.get_success(
+ filter_events_for_server(self.storage, "other_server", [outlier, evt])
+ )
+ self.assertEqual(filtered[0], outlier)
+ self.assertEqual(filtered[1].event_id, evt.event_id)
+ self.assertNotIn("body", filtered[1].content)
+
def test_erased_user(self) -> None:
# 4 message events, from erased and unerased users, with a membership
# change in the middle of them.
events_to_filter = []
- evt = self.get_success(self._inject_message("@unerased:local_hs"))
+ evt = self._inject_message("@unerased:local_hs")
events_to_filter.append(evt)
- evt = self.get_success(self._inject_message("@erased:local_hs"))
+ evt = self._inject_message("@erased:local_hs")
events_to_filter.append(evt)
- evt = self.get_success(self._inject_room_member("@joiner:remote_hs"))
+ evt = self._inject_room_member("@joiner:remote_hs")
events_to_filter.append(evt)
- evt = self.get_success(self._inject_message("@unerased:local_hs"))
+ evt = self._inject_message("@unerased:local_hs")
events_to_filter.append(evt)
- evt = self.get_success(self._inject_message("@erased:local_hs"))
+ evt = self._inject_message("@erased:local_hs")
events_to_filter.append(evt)
# the erasey user gets erased
@@ -187,6 +219,25 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.get_success(self.storage.persistence.persist_event(event, context))
return event
+ def _inject_outlier(self) -> EventBase:
+ builder = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": "m.room.member",
+ "sender": "@test:user",
+ "state_key": "@test:user",
+ "room_id": TEST_ROOM_ID,
+ "content": {"membership": "join"},
+ },
+ )
+
+ event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
+ event.internal_metadata.outlier = True
+ self.get_success(
+ self.storage.persistence.persist_event(event, EventContext.for_outlier())
+ )
+ return event
+
class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
def test_out_of_band_invite_rejection(self):
diff --git a/tests/unittest.py b/tests/unittest.py
index 326895f4c9..9afa68c164 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -16,17 +16,17 @@
import gc
import hashlib
import hmac
-import inspect
import json
import logging
import secrets
import time
from typing import (
Any,
- AnyStr,
+ Awaitable,
Callable,
ClassVar,
Dict,
+ Generic,
Iterable,
List,
Optional,
@@ -40,6 +40,7 @@ from unittest.mock import Mock, patch
import canonicaljson
import signedjson.key
import unpaddedbase64
+from typing_extensions import Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
@@ -50,7 +51,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse import events
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
@@ -71,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
-from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
+from tests.server import (
+ CustomHeaderType,
+ FakeChannel,
+ get_clock,
+ make_request,
+ setup_test_homeserver,
+)
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -79,6 +86,17 @@ from tests.utils import default_config, setupdb
setupdb()
setup_logging()
+TV = TypeVar("TV")
+_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
+
+
+class _TypedFailure(Generic[_ExcType], Protocol):
+ """Extension to twisted.Failure, where the 'value' has a certain type."""
+
+ @property
+ def value(self) -> _ExcType:
+ ...
+
def around(target):
"""A CLOS-style 'around' modifier, which wraps the original method of the
@@ -277,6 +295,7 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
+ assert self.helper.auth_user_id is not None
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
@@ -289,6 +308,7 @@ class HomeserverTestCase(TestCase):
)
async def get_user_by_access_token(token=None, allow_guest=False):
+ assert self.helper.auth_user_id is not None
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": token_id,
@@ -296,6 +316,7 @@ class HomeserverTestCase(TestCase):
}
async def get_user_by_req(request, allow_guest=False, rights="access"):
+ assert self.helper.auth_user_id is not None
return create_requester(
UserID.from_string(self.helper.auth_user_id),
token_id,
@@ -312,7 +333,7 @@ class HomeserverTestCase(TestCase):
)
if self.needs_threadpool:
- self.reactor.threadpool = ThreadPool()
+ self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
self.addCleanup(self.reactor.threadpool.stop)
self.reactor.threadpool.start()
@@ -427,7 +448,7 @@ class HomeserverTestCase(TestCase):
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
@@ -512,40 +533,36 @@ class HomeserverTestCase(TestCase):
return hs
- def pump(self, by=0.0):
+ def pump(self, by: float = 0.0) -> None:
"""
Pump the reactor enough that Deferreds will fire.
"""
self.reactor.pump([by] * 100)
- def get_success(self, d, by=0.0):
- if inspect.isawaitable(d):
- d = ensureDeferred(d)
- if not isinstance(d, Deferred):
- return d
+ def get_success(
+ self,
+ d: Awaitable[TV],
+ by: float = 0.0,
+ ) -> TV:
+ deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by)
- return self.successResultOf(d)
+ return self.successResultOf(deferred)
- def get_failure(self, d, exc):
+ def get_failure(
+ self, d: Awaitable[Any], exc: Type[_ExcType]
+ ) -> _TypedFailure[_ExcType]:
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
- if inspect.isawaitable(d):
- d = ensureDeferred(d)
- if not isinstance(d, Deferred):
- return d
+ deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type]
self.pump()
- return self.failureResultOf(d, exc)
+ return self.failureResultOf(deferred, exc)
- def get_success_or_raise(self, d, by=0.0):
+ def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
"""Drive deferred to completion and return result or raise exception
on failure.
"""
-
- if inspect.isawaitable(d):
- deferred = ensureDeferred(d)
- if not isinstance(deferred, Deferred):
- return d
+ deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
results: list = []
deferred.addBoth(results.append)
@@ -653,11 +670,11 @@ class HomeserverTestCase(TestCase):
def login(
self,
- username,
- password,
- device_id=None,
- custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
- ):
+ username: str,
+ password: str,
+ device_id: Optional[str] = None,
+ custom_headers: Optional[Iterable[CustomHeaderType]] = None,
+ ) -> str:
"""
Log in a user, and get an access token. Requires the Login API be
registered.
@@ -679,18 +696,22 @@ class HomeserverTestCase(TestCase):
return access_token
def create_and_send_event(
- self, room_id, user, soft_failed=False, prev_event_ids=None
- ):
+ self,
+ room_id: str,
+ user: UserID,
+ soft_failed: bool = False,
+ prev_event_ids: Optional[List[str]] = None,
+ ) -> str:
"""
Create and send an event.
Args:
- soft_failed (bool): Whether to create a soft failed event or not
- prev_event_ids (list[str]|None): Explicitly set the prev events,
+ soft_failed: Whether to create a soft failed event or not
+ prev_event_ids: Explicitly set the prev events,
or if None just use the default
Returns:
- str: The new event's ID.
+ The new event's ID.
"""
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
@@ -717,34 +738,7 @@ class HomeserverTestCase(TestCase):
return event.event_id
- def add_extremity(self, room_id, event_id):
- """
- Add the given event as an extremity to the room.
- """
- self.get_success(
- self.hs.get_datastores().main.db_pool.simple_insert(
- table="event_forward_extremities",
- values={"room_id": room_id, "event_id": event_id},
- desc="test_add_extremity",
- )
- )
-
- self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate(
- (room_id,)
- )
-
- def attempt_wrong_password_login(self, username, password):
- """Attempts to login as the user with the given password, asserting
- that the attempt *fails*.
- """
- body = {"type": "m.login.password", "user": username, "password": password}
-
- channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
- )
- self.assertEqual(channel.code, 403, channel.result)
-
- def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+ def inject_room_member(self, room: str, user: str, membership: str) -> None:
"""
Inject a membership event into a room.
@@ -804,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
path: str,
content: Optional[JsonDict] = None,
await_result: bool = True,
- custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""Make an inbound signed federation request to this server
@@ -837,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
self.site,
method=method,
path=path,
- content=content,
+ content=content or "",
shorthand=False,
await_result=await_result,
custom_headers=custom_headers,
@@ -916,9 +910,6 @@ def override_config(extra_config):
return decorator
-TV = TypeVar("TV")
-
-
def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
"""A test decorator which will skip the decorated test unless a condition is set
|