diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 6e403a87c5..11ad44223d 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import directory, login, room
-from synapse.types import RoomAlias, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomAlias, create_requester
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
- self.query_handlers = {}
+ self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
- def register_query_handler(query_type, handler):
+ def register_query_handler(
+ query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+ ) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return hs
- def test_get_local_association(self):
+ def test_get_local_association(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.my_room, "!8765qwer:test", ["test"]
@@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
- def test_get_remote_association(self):
+ def test_get_remote_association(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
@@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
- def test_incoming_fed_query(self):
+ def test_incoming_fed_query(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.your_room, "!8765asdf:test", ["test"]
@@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_directory_handler()
# Create user
@@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def test_create_alias_joined_room(self):
+ def test_create_alias_joined_room(self) -> None:
"""A user can create an alias for a room they're in."""
self.get_success(
self.handler.create_association(
@@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
)
)
- def test_create_alias_other_room(self):
+ def test_create_alias_other_room(self) -> None:
"""A user cannot create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
@@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_create_alias_admin(self):
+ def test_create_alias_admin(self) -> None:
"""An admin can create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.test_user, tok=self.test_user_tok
@@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def _create_alias(self, user):
+ def _create_alias(self, user) -> None:
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
@@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
)
)
- def test_delete_alias_not_allowed(self):
+ def test_delete_alias_not_allowed(self) -> None:
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
self._create_alias(self.admin_user)
self.get_failure(
@@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.AuthError,
)
- def test_delete_alias_creator(self):
+ def test_delete_alias_creator(self) -> None:
"""An alias creator can delete their own alias."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_delete_alias_admin(self):
+ def test_delete_alias_admin(self) -> None:
"""A server admin can delete an alias created by another user."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_delete_alias_sufficient_power(self):
+ def test_delete_alias_sufficient_power(self) -> None:
"""A user with a sufficient power level should be able to delete an alias."""
self._create_alias(self.admin_user)
@@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
return room_alias
- def _set_canonical_alias(self, content):
+ def _set_canonical_alias(self, content) -> None:
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id,
@@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
)
- def test_remove_alias(self):
+ def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too."""
# Set this new alias as the canonical alias for this room
self._set_canonical_alias(
@@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"])
- def test_remove_other_alias(self):
+ def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too."""
# Create a second alias.
other_test_alias = "#test2:test"
@@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom alias creation rules to the config.
@@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
return config
- def test_denied(self):
+ def test_denied(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(403, channel.code, channel.result)
- def test_allowed(self):
+ def test_allowed(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, channel.result)
- def test_denied_during_creation(self):
+ def test_denied_during_creation(self) -> None:
"""A room alias that is not allowed should be rejected during creation."""
# Invalid room alias.
self.helper.create_room_as(
@@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
extra_content={"room_alias_name": "foo"},
)
- def test_allowed_during_creation(self):
+ def test_allowed_during_creation(self) -> None:
"""A valid room alias should be allowed during creation."""
room_id = self.helper.create_room_as(
self.user_id,
@@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
data = {"room_alias_name": "unofficial_test"}
allowed_localpart = "allowed"
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom room list publication rules to the config.
@@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+ ) -> HomeServer:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
@@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return hs
- def test_denied_without_publication_permission(self):
+ def test_denied_without_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user without permission to publish rooms.
@@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
- def test_allowed_when_creating_private_room(self):
+ def test_allowed_when_creating_private_room(self) -> None:
"""
Try to create a room, register an alias for it, and NOT publish it,
as a user without permission to publish rooms.
@@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_allowed_with_publication_permission(self):
+ def test_allowed_with_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_denied_publication_with_invalid_alias(self):
+ def test_denied_publication_with_invalid_alias(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
- def test_can_create_as_private_room_after_rejection(self):
+ def test_can_create_as_private_room_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as the same user, but without publishing the room.
@@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.test_denied_without_publication_permission()
self.test_allowed_when_creating_private_room()
- def test_can_create_with_permission_after_rejection(self):
+ def test_can_create_with_permission_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as someone with permission, using the same alias.
@@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def prepare(self, reactor, clock, hs):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+ ) -> HomeServer:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
return hs
- def test_disabling_room_list(self):
+ def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 9338ab92e9..ac21a28c43 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -20,33 +20,37 @@ from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=mock.Mock())
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main
- def test_query_local_devices_no_devices(self):
+ def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list."""
local_user = "@boris:" + self.hs.hostname
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- def test_reupload_one_time_keys(self):
+ def test_reupload_one_time_keys(self) -> None:
"""we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
- keys = {
+ keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
@@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
)
- def test_change_one_time_keys(self):
+ def test_change_one_time_keys(self) -> None:
"""attempts to change one-time-keys should be rejected"""
local_user = "@boris:" + self.hs.hostname
@@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_claim_one_time_key(self):
+ def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {"alg1:k1": "key1"}
@@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_fallback_key(self):
+ def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
@@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
- def test_replace_master_key(self):
+ def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
- def test_reupload_signatures(self):
+ def test_reupload_signatures(self) -> None:
"""re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
- def test_self_signing_key_doesnt_show_up_as_device(self):
+ def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
"""signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- def test_upload_signatures(self):
+ def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will
# try uploading signatures
@@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
- def test_query_devices_remote_no_sync(self):
+ def test_query_devices_remote_no_sync(self) -> None:
"""Tests that querying keys for a remote user that we don't share a room
with returns the cross signing keys correctly.
"""
@@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_query_devices_remote_sync(self):
+ def test_query_devices_remote_sync(self) -> None:
"""Tests that querying keys for a remote user that we share a room with,
but haven't yet fetched the keys for, returns the cross signing keys
correctly.
@@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
(["device_1", "device_2"],),
]
)
- def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
"""Test that requests for all of a remote user's devices are cached.
We do this by asserting that only one call over federation was made, and that
@@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
"""
local_user_id = "@test:test"
remote_user_id = "@test:other"
- request_body = {"device_keys": {remote_user_id: []}}
+ request_body: JsonDict = {"device_keys": {remote_user_id: []}}
response_devices = [
{
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e8418b6638..014815db6e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,18 @@
# limitations under the License.
import json
import os
+from typing import Any, Dict
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
import pymacaroons
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-async def get_json(url):
+async def get_json(url: str) -> JsonDict:
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
# Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
elif url == JWKS_URI:
return {"keys": []}
+ return {}
+
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
skip = "requires OIDC"
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
- sso_handler.render_error = self.render_error
+ sso_handler.render_error = self.render_error # type: ignore[assignment]
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
return args
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_config(self):
+ def test_config(self) -> None:
"""Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
- def test_discovery(self):
+ def test_discovery(self) -> None:
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_no_discovery(self):
+ def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_load_jwks(self):
+ def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_validate_config(self):
+ def test_validate_config(self) -> None:
"""Provider metadatas are extensively validated."""
h = self.provider
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
force_load_metadata()
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
- def test_skip_verification(self):
+ def test_skip_verification(self) -> None:
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_redirect_request(self):
+ def test_redirect_request(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"])
req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_error(self):
+ def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback(self):
+ def test_callback(self) -> None:
"""Code callback works and display errors if something went wrong.
A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.provider._parse_id_token = simple_async_mock(raises=Exception())
+ self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"type": "bearer",
"access_token": "access_token",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
id_token = {
"sid": "abcdefgh",
}
- self.provider._parse_id_token = simple_async_mock(return_value=id_token)
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
auth_handler.complete_sso_login.reset_mock()
self.provider._fetch_userinfo.reset_mock()
self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
from synapse.handlers.oidc import OidcError
- self.provider._exchange_code = simple_async_mock(
+ self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_session(self):
+ def test_callback_session(self) -> None:
"""The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"])
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
)
- def test_exchange_code(self):
+ def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_jwt_key(self):
+ def test_exchange_code_jwt_key(self) -> None:
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_no_auth(self):
+ def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_extra_attributes(self):
+ def test_extra_attributes(self) -> None:
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo",
"phone": "1234567",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_user(self):
+ def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
- userinfo = {
+ userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
- def test_map_userinfo_to_existing_user(self):
+ def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastores().main
user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_invalid_localpart(self):
+ def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "fΓΆΓΆ"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_map_userinfo_to_user_retries(self):
+ def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_empty_localpart(self):
+ def test_empty_localpart(self) -> None:
"""Attempts to map onto an empty localpart should be rejected."""
userinfo = {
"sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_null_localpart(self):
+ def test_null_localpart(self) -> None:
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
userinfo = {
"sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements(self):
+ def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_contains(self):
+ def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_mismatch(self):
+ def test_attribute_requirements_mismatch(self) -> None:
"""
Test that auth fails if attributes exist but don't match,
or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
- userinfo = {
+ userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler()
provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
- provider._parse_id_token = simple_async_mock(return_value=userinfo)
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
+ provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
state = "state"
session = handler._token_generator.generate_oidc_session_token(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 972cbac6e4..08733a9f2d 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
- self.query_handlers = {}
+ self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
- def register_query_handler(query_type, handler):
+ def register_query_handler(
+ query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+ ) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
return hs
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.frank = UserID.from_string("@1234abcd:test")
@@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_profile_handler()
- def test_get_my_name(self):
+ def test_get_my_name(self) -> None:
self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
@@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("Frank", displayname)
- def test_set_my_name(self):
+ def test_set_my_name(self) -> None:
self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
)
- def test_set_my_name_if_disabled(self):
+ def test_set_my_name_if_disabled(self) -> None:
self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed
@@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_set_my_name_noauth(self):
+ def test_set_my_name_noauth(self) -> None:
self.get_failure(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
AuthError,
)
- def test_get_other_name(self):
+ def test_get_other_name(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
@@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
- def test_incoming_fed_query(self):
+ def test_incoming_fed_query(self) -> None:
self.get_success(self.store.create_profile("caroline"))
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
@@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual({"displayname": "Caroline"}, response)
- def test_get_my_avatar(self):
+ def test_get_my_avatar(self) -> None:
self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
@@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("http://my.server/me.png", avatar_url)
- def test_set_my_avatar(self):
+ def test_set_my_avatar(self) -> None:
self.get_success(
self.handler.set_avatar_url(
self.frank,
@@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
)
- def test_set_my_avatar_if_disabled(self):
+ def test_set_my_avatar_if_disabled(self) -> None:
self.hs.config.registration.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
@@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_avatar_constraints_no_config(self):
+ def test_avatar_constraints_no_config(self) -> None:
"""Tests that the method to check an avatar against configured constraints skips
all of its check if no constraint is configured.
"""
@@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertTrue(res)
@unittest.override_config({"max_avatar_size": 50})
- def test_avatar_constraints_missing(self):
+ def test_avatar_constraints_missing(self) -> None:
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
be found.
"""
@@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"max_avatar_size": 50})
- def test_avatar_constraints_file_size(self):
+ def test_avatar_constraints_file_size(self) -> None:
"""Tests that a file that's above the allowed file size is forbidden but one
that's below it is allowed.
"""
@@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
- def test_avatar_constraint_mime_type(self):
+ def test_avatar_constraint_mime_type(self) -> None:
"""Tests that a file with an unauthorised MIME type is forbidden but one with
an authorised content type is allowed.
"""
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 23941abed8..8d4404eda1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional
from unittest.mock import Mock
import attr
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import RedirectException
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
class SamlHandlerTestCase(HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
- saml_config = {
+ saml_config: Dict[str, Any] = {
"sp_config": {"metadata": {}},
# Disable grandfathering.
"grandfathered_mxid_source_attribute": None,
@@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
self.handler = hs.get_saml_handler()
@@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
elif not has_xmlsec1:
skip = "Requires xmlsec1"
- def test_map_saml_response_to_user(self):
+ def test_map_saml_response_to_user(self) -> None:
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
- def test_map_saml_response_to_existing_user(self):
+ def test_map_saml_response_to_existing_user(self) -> None:
"""Existing users can log in with SAML account."""
store = self.hs.get_datastores().main
self.get_success(
@@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_saml_response_to_invalid_localpart(self):
+ def test_map_saml_response_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler
@@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
auth_handler.complete_sso_login.assert_not_called()
- def test_map_saml_response_to_user_retries(self):
+ def test_map_saml_response_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer
@@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_map_saml_response_redirect(self):
+ def test_map_saml_response_redirect(self) -> None:
"""Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
},
}
)
- def test_attribute_requirements(self):
+ def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the SAML response."""
# stub out the auth handler
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 171f4e97c8..f3741b3001 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -79,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content: Optional[dict] = None,
access_token: Optional[str] = None,
parent_id: Optional[str] = None,
+ expected_response_code: int = 200,
) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
@@ -115,16 +116,50 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content,
access_token=access_token,
)
+ self.assertEqual(expected_response_code, channel.code, channel.json_body)
return channel
+ def _get_related_events(self) -> List[str]:
+ """
+ Requests /relations on the parent ID and returns a list of event IDs.
+ """
+ # Request the relations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ return [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+ def _get_bundled_aggregations(self) -> JsonDict:
+ """
+ Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists.
+ """
+ # Fetch the bundled aggregations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ return channel.json_body["unsigned"].get("m.relations", {})
+
+ def _get_aggregations(self) -> List[JsonDict]:
+ """Request /aggregations on the parent ID and includes the returned chunk."""
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ return channel.json_body["chunk"]
+
class RelationsTestCase(BaseRelationsTestCase):
def test_send_relation(self) -> None:
"""Tests that sending a relation works."""
-
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="π")
- self.assertEqual(200, channel.code, channel.json_body)
-
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -151,13 +186,13 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_deny_invalid_event(self) -> None:
"""Test that we deny relations on non-existant events"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.ANNOTATION,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
# Unless that event is referenced from another event!
self.get_success(
@@ -171,13 +206,12 @@ class RelationsTestCase(BaseRelationsTestCase):
desc="test_deny_invalid_event",
)
)
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
def test_deny_invalid_room(self) -> None:
"""Test that we deny relations on non-existant events"""
@@ -187,18 +221,20 @@ class RelationsTestCase(BaseRelationsTestCase):
parent_id = res["event_id"]
# Attempt to send an annotation to that event.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ parent_id=parent_id,
+ key="A",
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
def test_deny_double_react(self) -> None:
"""Test that we deny relations on membership events"""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(400, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400
+ )
def test_deny_forked_thread(self) -> None:
"""It is invalid to start a thread off a thread."""
@@ -208,316 +244,24 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo"},
parent_id=self.parent_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
parent_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "foo"},
parent_id=parent_id,
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
-
- def test_basic_paginate_relations(self) -> None:
- """Tests that calling pagination API correctly the latest relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- first_annotation_id = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
- second_annotation_id = channel.json_body["event_id"]
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the latest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": second_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- # We also expect to get the original event (the id of which is self.parent_id)
- self.assertEqual(
- channel.json_body["original_event"]["event_id"], self.parent_id
- )
-
- # Make sure next_batch has something in it that looks like it could be a
- # valid token.
- self.assertIsInstance(
- channel.json_body.get("next_batch"), str, channel.json_body
- )
-
- # Request the relations again, but with a different direction.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the earliest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": first_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- def test_repeated_paginate_relations(self) -> None:
- """Test that if we paginate using a limit and tokens then we get the
- expected events.
- """
-
- expected_event_ids = []
- for idx in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- prev_token = ""
- found_event_ids: List[str] = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- def test_pagination_from_sync_and_messages(self) -> None:
- """Pagination tokens from /sync and /messages can be used to paginate /relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
- self.assertEqual(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
- # Send an event after the relation events.
- self.helper.send(self.room, body="Latest event", tok=self.user_token)
-
- # Request /sync, limiting it such that only the latest event is returned
- # (and not the relation).
- filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
- channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- sync_prev_batch = room_timeline["prev_batch"]
- self.assertIsNotNone(sync_prev_batch)
- # Ensure the relation event is not in the batch returned from /sync.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
- )
-
- # Request /messages, limiting it such that only the latest event is
- # returned (and not the relation).
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b&limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- messages_end = channel.json_body["end"]
- self.assertIsNotNone(messages_end)
- # Ensure the relation event is not in the chunk returned from /messages.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
- )
-
- # Request /relations with the pagination tokens received from both the
- # /sync and /messages responses above, in turn.
- #
- # This is a tiny bit silly since the client wouldn't know the parent ID
- # from the requests above; consider the parent ID to be known from a
- # previous /sync.
- for from_token in (sync_prev_batch, messages_end):
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # The relation should be in the returned chunk.
- self.assertIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
- )
-
- def test_aggregation_pagination_groups(self) -> None:
- """Test that we can paginate annotation groups correctly."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- sent_groups = {"π": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
- for key in itertools.chain.from_iterable(
- itertools.repeat(key, num) for key, num in sent_groups.items()
- ):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key=key,
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- idx += 1
- idx %= len(access_tokens)
-
- prev_token: Optional[str] = None
- found_groups: Dict[str, int] = {}
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- for groups in channel.json_body["chunk"]:
- # We only expect reactions
- self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
- # We should only see each key once
- self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
- found_groups[groups["key"]] = groups["count"]
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- self.assertEqual(sent_groups, found_groups)
-
- def test_aggregation_pagination_within_group(self) -> None:
- """Test that we can paginate within an annotation group."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- expected_event_ids = []
- for _ in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="π",
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- idx += 1
-
- # Also send a different type of reaction so that we test we don't see it
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- prev_token = ""
- found_event_ids: List[str] = []
- encoded_key = urllib.parse.quote_plus("π".encode())
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
def test_aggregation(self) -> None:
"""Test that annotations get correctly aggregated."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
channel = self.make_request(
"GET",
@@ -558,30 +302,21 @@ class RelationsTestCase(BaseRelationsTestCase):
See test_edit for a similar test for edits.
"""
# Setup by sending a variety of relations.
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
reply_1 = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
reply_2 = channel.json_body["event_id"]
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
thread_2 = channel.json_body["event_id"]
def assert_bundle(event_json: JsonDict) -> None:
@@ -693,14 +428,12 @@ class RelationsTestCase(BaseRelationsTestCase):
when directly requested.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_id = channel.json_body["event_id"]
# Annotate the annotation.
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
)
- self.assertEqual(200, channel.code, channel.json_body)
channel = self.make_request(
"GET",
@@ -713,14 +446,12 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_aggregation_get_event_for_thread(self) -> None:
"""Test that threads get bundled aggregations included when directly requested."""
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
thread_id = channel.json_body["event_id"]
# Annotate the annotation.
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
)
- self.assertEqual(200, channel.code, channel.json_body)
channel = self.make_request(
"GET",
@@ -877,8 +608,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
def assert_bundle(event_json: JsonDict) -> None:
@@ -954,7 +683,7 @@ class RelationsTestCase(BaseRelationsTestCase):
shouldn't be allowed, are correctly handled.
"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -963,7 +692,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
@@ -971,11 +699,9 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message.WRONG_TYPE",
content={
@@ -984,7 +710,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
channel = self.make_request(
"GET",
@@ -1015,7 +740,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
reply = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1025,8 +749,6 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=reply,
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1071,7 +793,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A threaded reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
threaded_event_id = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1081,7 +802,6 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=threaded_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Fetch the thread root, to get the bundled aggregation for the thread.
channel = self.make_request(
@@ -1113,7 +833,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": new_body,
},
)
- self.assertEqual(200, channel.code, channel.json_body)
edit_event_id = channel.json_body["event_id"]
# Edit the edit event.
@@ -1127,7 +846,6 @@ class RelationsTestCase(BaseRelationsTestCase):
},
parent_id=edit_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Request the original event.
channel = self.make_request(
@@ -1154,7 +872,6 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_unknown_relations(self) -> None:
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1208,15 +925,12 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="π")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_good = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_bad = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
thread_event_id = channel.json_body["event_id"]
# Clean-up the table as if the inserts did not happen during event creation.
@@ -1296,71 +1010,312 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertIn("m.relations", parent_event["unsigned"])
-class RelationRedactionTestCase(BaseRelationsTestCase):
- """
- Test the behaviour of relations when the parent or child event is redacted.
+class RelationPaginationTestCase(BaseRelationsTestCase):
+ def test_basic_paginate_relations(self) -> None:
+ """Tests that calling pagination API correctly the latest relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ first_annotation_id = channel.json_body["event_id"]
- The behaviour of each relation type is subtly different which causes the tests
- to be a bit repetitive, they follow a naming scheme of:
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ second_annotation_id = channel.json_body["event_id"]
- test_redact_(relation|parent)_{relation_type}
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
- The first bit of "relation" means that the event with the relation defined
- on it (the child event) is to be redacted. A "parent" means that the target
- of the relation (the parent event) is to be redacted.
+ # We expect to get back a single pagination result, which is the latest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": second_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
- The relation_type describes which type of relation is under test (i.e. it is
- related to the value of rel_type in the event content).
- """
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEqual(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
- def _redact(self, event_id: str) -> None:
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ self.assertIsInstance(
+ channel.json_body.get("next_batch"), str, channel.json_body
+ )
+
+ # Request the relations again, but with a different direction.
channel = self.make_request(
- "POST",
- f"/_matrix/client/r0/rooms/{self.room}/redact/{event_id}",
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations"
+ f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
access_token=self.user_token,
- content={},
)
self.assertEqual(200, channel.code, channel.json_body)
- def _make_relation_requests(self) -> Tuple[List[str], JsonDict]:
- """
- Makes requests and ensures they result in a 200 response, returns a
- tuple of results:
+ # We expect to get back a single pagination result, which is the earliest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": first_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
- 1. `/relations` -> Returns a list of event IDs.
- 2. `/event` -> Returns the response's m.relations field (from unsigned),
- if it exists.
+ def test_repeated_paginate_relations(self) -> None:
+ """Test that if we paginate using a limit and tokens then we get the
+ expected events.
"""
- # Request the relations of the event.
+ expected_event_ids = []
+ for idx in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+ def test_pagination_from_sync_and_messages(self) -> None:
+ """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+ annotation_id = channel.json_body["event_id"]
+ # Send an event after the relation events.
+ self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+ # Request /sync, limiting it such that only the latest event is returned
+ # (and not the relation).
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
- access_token=self.user_token,
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ sync_prev_batch = room_timeline["prev_batch"]
+ self.assertIsNotNone(sync_prev_batch)
+ # Ensure the relation event is not in the batch returned from /sync.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
)
- self.assertEquals(200, channel.code, channel.json_body)
- event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
- # Fetch the bundled aggregations of the event.
+ # Request /messages, limiting it such that only the latest event is
+ # returned (and not the relation).
channel = self.make_request(
"GET",
- f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+ f"/rooms/{self.room}/messages?dir=b&limit=1",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
- bundled_relations = channel.json_body["unsigned"].get("m.relations", {})
+ self.assertEqual(200, channel.code, channel.json_body)
+ messages_end = channel.json_body["end"]
+ self.assertIsNotNone(messages_end)
+ # Ensure the relation event is not in the chunk returned from /messages.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ # Request /relations with the pagination tokens received from both the
+ # /sync and /messages responses above, in turn.
+ #
+ # This is a tiny bit silly since the client wouldn't know the parent ID
+ # from the requests above; consider the parent ID to be known from a
+ # previous /sync.
+ for from_token in (sync_prev_batch, messages_end):
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
- return event_ids, bundled_relations
+ # The relation should be in the returned chunk.
+ self.assertIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
- def _get_aggregations(self) -> List[JsonDict]:
- """Request /aggregations on the parent ID and includes the returned chunk."""
+ def test_aggregation_pagination_groups(self) -> None:
+ """Test that we can paginate annotation groups correctly."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ sent_groups = {"π": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
+ for key in itertools.chain.from_iterable(
+ itertools.repeat(key, num) for key, num in sent_groups.items()
+ ):
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=key,
+ access_token=access_tokens[idx],
+ )
+
+ idx += 1
+ idx %= len(access_tokens)
+
+ prev_token: Optional[str] = None
+ found_groups: Dict[str, int] = {}
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ for groups in channel.json_body["chunk"]:
+ # We only expect reactions
+ self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+ # We should only see each key once
+ self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+ found_groups[groups["key"]] = groups["count"]
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEqual(sent_groups, found_groups)
+
+ def test_aggregation_pagination_within_group(self) -> None:
+ """Test that we can paginate within an annotation group."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="π",
+ access_token=access_tokens[idx],
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ idx += 1
+
+ # Also send a different type of reaction so that we test we don't see it
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ encoded_key = urllib.parse.quote_plus("π".encode())
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}"
+ f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+ f"/m.reaction/{encoded_key}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+
+class RelationRedactionTestCase(BaseRelationsTestCase):
+ """
+ Test the behaviour of relations when the parent or child event is redacted.
+
+ The behaviour of each relation type is subtly different which causes the tests
+ to be a bit repetitive, they follow a naming scheme of:
+
+ test_redact_(relation|parent)_{relation_type}
+
+ The first bit of "relation" means that the event with the relation defined
+ on it (the child event) is to be redacted. A "parent" means that the target
+ of the relation (the parent event) is to be redacted.
+
+ The relation_type describes which type of relation is under test (i.e. it is
+ related to the value of rel_type in the event content).
+ """
+
+ def _redact(self, event_id: str) -> None:
channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+ "POST",
+ f"/_matrix/client/r0/rooms/{self.room}/redact/{event_id}",
access_token=self.user_token,
+ content={},
)
self.assertEqual(200, channel.code, channel.json_body)
- return channel.json_body["chunk"]
def test_redact_relation_annotation(self) -> None:
"""
@@ -1371,17 +1326,16 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
the response to relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
unredacted_event_id = channel.json_body["event_id"]
# Both relations should exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
self.assertEquals(
relations["m.annotation"],
@@ -1396,7 +1350,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(to_redact_event_id)
# The unredacted relation should still exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [unredacted_event_id])
self.assertEquals(
relations["m.annotation"],
@@ -1419,7 +1374,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
EventTypes.Message,
content={"body": "reply 1", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
unredacted_event_id = channel.json_body["event_id"]
# Note that the *last* event in the thread is redacted, as that gets
@@ -1429,11 +1383,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
EventTypes.Message,
content={"body": "reply 2", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
# Both relations exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
self.assertDictContainsSubset(
{
@@ -1452,7 +1406,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(to_redact_event_id)
# The unredacted relation should still exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [unredacted_event_id])
self.assertDictContainsSubset(
{
@@ -1472,7 +1427,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
is redacted.
"""
# Add a relation
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
parent_id=self.parent_id,
@@ -1482,10 +1437,10 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
# Check the relation is returned
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEqual(len(event_ids), 1)
self.assertIn(RelationTypes.REPLACE, relations)
@@ -1493,7 +1448,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(self.parent_id)
# The relations are not returned.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEqual(len(event_ids), 0)
self.assertEqual(relations, {})
@@ -1503,11 +1459,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
"""
# Add a relation
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="π")
- self.assertEqual(200, channel.code, channel.json_body)
related_event_id = channel.json_body["event_id"]
# The relations should exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEqual(len(event_ids), 1)
self.assertIn(RelationTypes.ANNOTATION, relations)
@@ -1519,7 +1475,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(self.parent_id)
# The relations are returned.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [related_event_id])
self.assertEquals(
relations["m.annotation"],
@@ -1540,14 +1497,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
EventTypes.Message,
content={"body": "reply 1", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
related_event_id = channel.json_body["event_id"]
# Redact one of the reactions.
self._redact(self.parent_id)
# The unredacted relation should still exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(len(event_ids), 1)
self.assertDictContainsSubset(
{
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index 3fb37a2a59..62e308814d 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import (
_get_html_media_encodings,
decode_body,
parse_html_to_open_graph,
- rebase_url,
summarize_paragraphs,
)
@@ -161,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -177,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
@@ -218,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -232,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -247,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -262,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -290,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase):
<head><title>Foo</title></head><body>Some text.</body></html>
""".strip()
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding(self) -> None:
@@ -304,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self) -> None:
@@ -319,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ΓΏΓΏ Foo", "og:description": "Some text."})
def test_windows_1252(self) -> None:
@@ -333,7 +332,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Γ³", "og:description": "Some text."})
@@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase):
'text/html; charset="invalid"',
)
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
-
-
-class RebaseUrlTestCase(unittest.TestCase):
- def test_relative(self) -> None:
- """Relative URLs should be resolved based on the context of the base URL."""
- self.assertEqual(
- rebase_url("subpage", "https://example.com/foo/"),
- "https://example.com/foo/subpage",
- )
- self.assertEqual(
- rebase_url("sibling", "https://example.com/foo"),
- "https://example.com/sibling",
- )
- self.assertEqual(
- rebase_url("/bar", "https://example.com/foo/"),
- "https://example.com/bar",
- )
-
- def test_absolute(self) -> None:
- """Absolute URLs should not be modified."""
- self.assertEqual(
- rebase_url("https://alice.com/a/", "https://example.com/foo/"),
- "https://alice.com/a/",
- )
-
- def test_data(self) -> None:
- """Data URLs should not be modified."""
- self.assertEqual(
- rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
- "data:,Hello%2C%20World%21",
- )
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 272cd35402..72bf5b3d31 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
expected_ignorer_user_ids,
)
+ def assert_ignored(
+ self, ignorer_user_id: str, expected_ignored_user_ids: Set[str]
+ ) -> None:
+ self.assertEqual(
+ self.get_success(self.store.ignored_users(ignorer_user_id)),
+ expected_ignored_user_ids,
+ )
+
def test_ignoring_users(self):
"""Basic adding/removing of users from the ignore list."""
self._update_ignore_list("@other:test", "@another:remote")
+ self.assert_ignored(self.user, {"@other:test", "@another:remote"})
# Check a user which no one ignores.
self.assert_ignorers("@user:test", set())
@@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Add one user, remove one user, and leave one user.
self._update_ignore_list("@foo:test", "@another:remote")
+ self.assert_ignored(self.user, {"@foo:test", "@another:remote"})
# Check the removed user.
self.assert_ignorers("@other:test", set())
@@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
"""Ensure that caching works properly between different users."""
# The first user ignores a user.
self._update_ignore_list("@other:test")
+ self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user})
# The second user ignores them.
self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+ self.assert_ignored("@second:test", {"@other:test"})
self.assert_ignorers("@other:test", {self.user, "@second:test"})
# The first user un-ignores them.
self._update_ignore_list()
+ self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", {"@second:test"})
def test_invalid_data(self):
"""Invalid data ends up clearing out the ignored users list."""
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")
+ self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user})
# No ignored_users key.
@@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
)
# No one ignores the user now.
+ self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set())
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")
+ self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user})
# Invalid data.
@@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
)
# No one ignores the user now.
+ self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 8597867563..a40fc20ef9 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -12,7 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.database import make_tuple_comparison_clause
+from typing import Callable, Tuple
+from unittest.mock import Mock, call
+
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
+from synapse.util import Clock
from tests import unittest
@@ -22,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase):
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])
+
+
+class CallbacksTestCase(unittest.HomeserverTestCase):
+ """Tests for transaction callbacks."""
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ def _run_interaction(
+ self, func: Callable[[LoggingTransaction], object]
+ ) -> Tuple[Mock, Mock]:
+ """Run the given function in a database transaction, with callbacks registered.
+
+ Args:
+ func: The function to be run in a transaction. The transaction will be
+ retried if `func` raises an `OperationalError`.
+
+ Returns:
+ Two mocks, which were registered as an `after_callback` and an
+ `exception_callback` respectively, on every transaction attempt.
+ """
+ after_callback = Mock()
+ exception_callback = Mock()
+
+ def _test_txn(txn: LoggingTransaction) -> None:
+ txn.call_after(after_callback, 123, 456, extra=789)
+ txn.call_on_exception(exception_callback, 987, 654, extra=321)
+ func(txn)
+
+ try:
+ self.get_success_or_raise(
+ self.db_pool.runInteraction("test_transaction", _test_txn)
+ )
+ except Exception:
+ pass
+
+ return after_callback, exception_callback
+
+ def test_after_callback(self) -> None:
+ """Test that the after callback is called when a transaction succeeds."""
+ after_callback, exception_callback = self._run_interaction(lambda txn: None)
+
+ after_callback.assert_called_once_with(123, 456, extra=789)
+ exception_callback.assert_not_called()
+
+ def test_exception_callback(self) -> None:
+ """Test that the exception callback is called when a transaction fails."""
+ _test_txn = Mock(side_effect=ZeroDivisionError)
+ after_callback, exception_callback = self._run_interaction(_test_txn)
+
+ after_callback.assert_not_called()
+ exception_callback.assert_called_once_with(987, 654, extra=321)
+
+ def test_failed_retry(self) -> None:
+ """Test that the exception callback is called for every failed attempt."""
+ # Always raise an `OperationalError`.
+ _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
+ after_callback, exception_callback = self._run_interaction(_test_txn)
+
+ after_callback.assert_not_called()
+ exception_callback.assert_has_calls(
+ [
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ ]
+ )
+ self.assertEqual(exception_callback.call_count, 6) # no additional calls
+
+ def test_successful_retry(self) -> None:
+ """Test callbacks for a failed transaction followed by a successful attempt."""
+ # Raise an `OperationalError` on the first attempt only.
+ _test_txn = Mock(
+ side_effect=[self.db_pool.engine.module.OperationalError, None]
+ )
+ after_callback, exception_callback = self._run_interaction(_test_txn)
+
+ # Calling both `after_callback`s when the first attempt failed is rather
+ # surprising (#12184). Let's document the behaviour in a test.
+ after_callback.assert_has_calls(
+ [
+ call(123, 456, extra=789),
+ call(123, 456, extra=789),
+ ]
+ )
+ self.assertEqual(after_callback.call_count, 2) # no additional calls
+ exception_callback.assert_not_called()
+
+
+class CancellationTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ def test_after_callback(self) -> None:
+ """Test that the after callback is called when a transaction succeeds."""
+ d: "Deferred[None]"
+ after_callback = Mock()
+ exception_callback = Mock()
+
+ def _test_txn(txn: LoggingTransaction) -> None:
+ txn.call_after(after_callback, 123, 456, extra=789)
+ txn.call_on_exception(exception_callback, 987, 654, extra=321)
+ d.cancel()
+
+ d = defer.ensureDeferred(
+ self.db_pool.runInteraction("test_transaction", _test_txn)
+ )
+ self.get_failure(d, CancelledError)
+
+ after_callback.assert_called_once_with(123, 456, extra=789)
+ exception_callback.assert_not_called()
+
+ def test_exception_callback(self) -> None:
+ """Test that the exception callback is called when a transaction fails."""
+ d: "Deferred[None]"
+ after_callback = Mock()
+ exception_callback = Mock()
+
+ def _test_txn(txn: LoggingTransaction) -> None:
+ txn.call_after(after_callback, 123, 456, extra=789)
+ txn.call_on_exception(exception_callback, 987, 654, extra=321)
+ d.cancel()
+ # Simulate a retryable failure on every attempt.
+ raise self.db_pool.engine.module.OperationalError()
+
+ d = defer.ensureDeferred(
+ self.db_pool.runInteraction("test_transaction", _test_txn)
+ )
+ self.get_failure(d, CancelledError)
+
+ after_callback.assert_not_called()
+ exception_callback.assert_has_calls(
+ [
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ ]
+ )
+ self.assertEqual(exception_callback.call_count, 6) # no additional calls
|