diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 15fce165b6..3c635e3dcb 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -11,18 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Mapping, Sequence, Union
+from typing import Any, List, Mapping, Optional, Sequence, Union
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.errors import HttpResponseException
from synapse.appservice import ApplicationService
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
+from tests.unittest import override_config
PROTOCOL = "myproto"
TOKEN = "myastoken"
@@ -40,7 +40,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
hs_token=TOKEN,
)
- def test_query_3pe_authenticates_token(self) -> None:
+ def test_query_3pe_authenticates_token_via_header(self) -> None:
"""
Tests that 3pe queries to the appservice are authenticated
with the appservice's token.
@@ -75,12 +75,16 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
args: Mapping[Any, Any],
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
) -> List[JsonDict]:
- # Ensure the access token is passed as both a header and query arg.
- if not headers.get("Authorization") or not args.get(b"access_token"):
+ # Ensure the access token is passed as a header.
+ if not headers or not headers.get("Authorization"):
raise RuntimeError("Access token not provided")
+ # ... and not as a query param
+ if b"access_token" in args:
+ raise RuntimeError(
+ "Access token should not be passed as a query param."
+ )
self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
- self.assertEqual(args.get(b"access_token"), TOKEN)
self.request_url = url
if url == URL_USER:
return SUCCESS_RESULT_USER
@@ -107,10 +111,13 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.request_url, URL_LOCATION)
self.assertEqual(result, SUCCESS_RESULT_LOCATION)
- def test_fallback(self) -> None:
+ @override_config({"use_appservice_legacy_authorization": True})
+ def test_query_3pe_authenticates_token_via_param(self) -> None:
"""
- Tests that the fallback to legacy URLs works.
+ Tests that 3pe queries to the appservice are authenticated
+ with the appservice's token.
"""
+
SUCCESS_RESULT_USER = [
{
"protocol": PROTOCOL,
@@ -120,30 +127,41 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
},
}
]
+ SUCCESS_RESULT_LOCATION = [
+ {
+ "protocol": PROTOCOL,
+ "alias": "#a:room",
+ "fields": {
+ "more": "fields",
+ },
+ }
+ ]
URL_USER = f"{URL}/_matrix/app/v1/thirdparty/user/{PROTOCOL}"
- FALLBACK_URL_USER = f"{URL}/_matrix/app/unstable/thirdparty/user/{PROTOCOL}"
+ URL_LOCATION = f"{URL}/_matrix/app/v1/thirdparty/location/{PROTOCOL}"
self.request_url = None
- self.v1_seen = False
async def get_json(
url: str,
args: Mapping[Any, Any],
- headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
+ headers: Optional[
+ Mapping[Union[str, bytes], Sequence[Union[str, bytes]]]
+ ] = None,
) -> List[JsonDict]:
- # Ensure the access token is passed as both a header and query arg.
- if not headers.get("Authorization") or not args.get(b"access_token"):
- raise RuntimeError("Access token not provided")
+ # Ensure the access token is passed as a both a query param and in the headers.
+ if not args.get(b"access_token"):
+ raise RuntimeError("Access token should be provided in query params.")
+ if not headers or not headers.get("Authorization"):
+ raise RuntimeError("Access token should be provided in auth headers.")
- self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
self.assertEqual(args.get(b"access_token"), TOKEN)
+ self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
self.request_url = url
if url == URL_USER:
- self.v1_seen = True
- raise HttpResponseException(404, "NOT_FOUND", b"NOT_FOUND")
- elif url == FALLBACK_URL_USER:
return SUCCESS_RESULT_USER
+ elif url == URL_LOCATION:
+ return SUCCESS_RESULT_LOCATION
else:
raise RuntimeError(
"URL provided was invalid. This should never be seen."
@@ -155,9 +173,15 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
result = self.get_success(
self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]})
)
- self.assertTrue(self.v1_seen)
- self.assertEqual(self.request_url, FALLBACK_URL_USER)
+ self.assertEqual(self.request_url, URL_USER)
self.assertEqual(result, SUCCESS_RESULT_USER)
+ result = self.get_success(
+ self.api.query_3pe(
+ self.service, "location", PROTOCOL, {b"some": [b"field"]}
+ )
+ )
+ self.assertEqual(self.request_url, URL_LOCATION)
+ self.assertEqual(result, SUCCESS_RESULT_LOCATION)
def test_claim_keys(self) -> None:
"""
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 196ceb0b82..ec2f5d30be 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -179,6 +179,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("http://my.server/me.png", avatar_url)
+ def test_get_profile_empty_displayname(self) -> None:
+ self.get_success(self.store.set_profile_displayname(self.frank, None))
+ self.get_success(
+ self.store.set_profile_avatar_url(self.frank, "http://my.server/me.png")
+ )
+
+ profile = self.get_success(self.handler.get_profile(self.frank.to_string()))
+
+ self.assertEqual("http://my.server/me.png", profile["avatar_url"])
+
def test_set_my_avatar(self) -> None:
self.get_success(
self.handler.set_avatar_url(
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
new file mode 100644
index 0000000000..73e548726c
--- /dev/null
+++ b/tests/handlers/test_worker_lock.py
@@ -0,0 +1,74 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+
+class WorkerLockTestCase(unittest.HomeserverTestCase):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self.worker_lock_handler = self.hs.get_worker_locks_handler()
+
+ def test_wait_for_lock_locally(self) -> None:
+ """Test waiting for a lock on a single worker"""
+
+ lock1 = self.worker_lock_handler.acquire_lock("name", "key")
+ self.get_success(lock1.__aenter__())
+
+ lock2 = self.worker_lock_handler.acquire_lock("name", "key")
+ d2 = defer.ensureDeferred(lock2.__aenter__())
+ self.assertNoResult(d2)
+
+ self.get_success(lock1.__aexit__(None, None, None))
+
+ self.get_success(d2)
+ self.get_success(lock2.__aexit__(None, None, None))
+
+
+class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self.main_worker_lock_handler = self.hs.get_worker_locks_handler()
+
+ def test_wait_for_lock_worker(self) -> None:
+ """Test waiting for a lock on another worker"""
+
+ worker = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "redis": {"enabled": True},
+ },
+ )
+ worker_lock_handler = worker.get_worker_locks_handler()
+
+ lock1 = self.main_worker_lock_handler.acquire_lock("name", "key")
+ self.get_success(lock1.__aenter__())
+
+ lock2 = worker_lock_handler.acquire_lock("name", "key")
+ d2 = defer.ensureDeferred(lock2.__aenter__())
+ self.assertNoResult(d2)
+
+ self.get_success(lock1.__aexit__(None, None, None))
+
+ self.get_success(d2)
+ self.get_success(lock2.__aexit__(None, None, None))
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 1e06f86071..829b9df83d 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -409,12 +409,12 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
)
- # Room mentions from those without power should not notify.
+ # The edit should not cause a notification.
self.assertFalse(
self._create_and_process(
bulk_evaluator,
{
- "body": self.alice,
+ "body": "Test message",
"m.relates_to": {
"rel_type": RelationTypes.REPLACE,
"event_id": event.event_id,
@@ -422,3 +422,20 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
+
+ # An edit which is a mention will cause a notification.
+ self.assertTrue(
+ self._create_and_process(
+ bulk_evaluator,
+ {
+ "body": "Test message",
+ "m.relates_to": {
+ "rel_type": RelationTypes.REPLACE,
+ "event_id": event.event_id,
+ },
+ "m.mentions": {
+ "user_ids": [self.alice],
+ },
+ },
+ )
+ )
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index b7d420cfec..3cf29c10ea 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -379,4 +379,141 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
access_token=token,
shorthand=False,
)
- self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.code, 401)
+
+ @unittest.override_config(
+ {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+ )
+ def test_msc3814_dehydrated_device_delete_works(self) -> None:
+ user = self.register_user("mikey", "pass")
+ token = self.login(user, "pass", device_id="device1")
+ content: JsonDict = {
+ "device_data": {
+ "algorithm": "m.dehydration.v1.olm",
+ },
+ "device_id": "device2",
+ "initial_device_display_name": "foo bar",
+ "device_keys": {
+ "user_id": "@mikey:test",
+ "device_id": "device2",
+ "valid_until_ts": "80",
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ ],
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures": {
+ "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+ },
+ },
+ }
+ channel = self.make_request(
+ "PUT",
+ "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+ content=content,
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ device_id = channel.json_body.get("device_id")
+ assert device_id is not None
+ self.assertIsInstance(device_id, str)
+ self.assertEqual("device2", device_id)
+
+ # ensure that keys were uploaded and available
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ user: ["device2"],
+ },
+ },
+ token,
+ )
+ self.assertEqual(
+ channel.json_body["device_keys"][user]["device2"]["keys"],
+ {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ )
+
+ # delete the dehydrated device
+ channel = self.make_request(
+ "DELETE",
+ "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # ensure that keys are no longer available for deleted device
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ user: ["device2"],
+ },
+ },
+ token,
+ )
+ self.assertEqual(channel.json_body["device_keys"], {"@mikey:test": {}})
+
+ # check that an old device is deleted when user PUTs a new device
+ # First, create a device
+ content["device_id"] = "device3"
+ content["device_keys"]["device_id"] = "device3"
+ channel = self.make_request(
+ "PUT",
+ "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+ content=content,
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ device_id = channel.json_body.get("device_id")
+ assert device_id is not None
+ self.assertIsInstance(device_id, str)
+ self.assertEqual("device3", device_id)
+
+ # create a second device without deleting first device
+ content["device_id"] = "device4"
+ content["device_keys"]["device_id"] = "device4"
+ channel = self.make_request(
+ "PUT",
+ "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+ content=content,
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ device_id = channel.json_body.get("device_id")
+ assert device_id is not None
+ self.assertIsInstance(device_id, str)
+ self.assertEqual("device4", device_id)
+
+ # check that the second device that was created is what is returned when we GET
+ channel = self.make_request(
+ "GET",
+ "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ returned_device_id = channel.json_body["device_id"]
+ self.assertEqual(returned_device_id, "device4")
+
+ # and that if we query the keys for the first device they are not there
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ user: ["device3"],
+ },
+ },
+ token,
+ )
+ self.assertEqual(channel.json_body["device_keys"], {"@mikey:test": {}})
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 27c93ad761..ecae092b47 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -68,6 +68,18 @@ class ProfileTestCase(unittest.HomeserverTestCase):
res = self._get_displayname()
self.assertEqual(res, "test")
+ def test_set_displayname_with_extra_spaces(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.owner,),
+ content={"displayname": " test "},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ res = self._get_displayname()
+ self.assertEqual(res, "test")
+
def test_set_displayname_noauth(self) -> None:
channel = self.make_request(
"PUT",
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 6028886bd6..180b635ea6 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -13,10 +13,12 @@
# limitations under the License.
from typing import List, Optional
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, RelationTypes
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
@@ -569,50 +571,81 @@ class RedactionsTestCase(HomeserverTestCase):
self.assertIn("body", event_dict["content"], event_dict)
self.assertEqual("I'm in a thread!", event_dict["content"]["body"])
- def test_content_redaction(self) -> None:
- """MSC2174 moved the redacts property to the content."""
+ @parameterized.expand(
+ [
+ # Tuples of:
+ # Room version
+ # Boolean: True if the redaction event content should include the event ID.
+ # Boolean: true if the resulting redaction event is expected to include the
+ # event ID in the content.
+ (RoomVersions.V10, False, False),
+ (RoomVersions.V11, True, True),
+ (RoomVersions.V11, False, True),
+ ]
+ )
+ def test_redaction_content(
+ self, room_version: RoomVersion, include_content: bool, expect_content: bool
+ ) -> None:
+ """
+ Room version 11 moved the redacts property to the content.
+
+ Ensure that the event gets created properly and that the Client-Server
+ API servers the proper backwards-compatible version.
+ """
# Create a room with the newer room version.
room_id = self.helper.create_room_as(
self.mod_user_id,
tok=self.mod_access_token,
- room_version=RoomVersions.V11.identifier,
+ room_version=room_version.identifier,
)
# Create an event.
b = self.helper.send(room_id=room_id, tok=self.mod_access_token)
event_id = b["event_id"]
- # Attempt to redact it with a bogus event ID.
- self._redact_event(
+ # Ensure the event ID in the URL and the content must match.
+ if include_content:
+ self._redact_event(
+ self.mod_access_token,
+ room_id,
+ event_id,
+ expect_code=400,
+ content={"redacts": "foo"},
+ )
+
+ # Redact it for real.
+ result = self._redact_event(
self.mod_access_token,
room_id,
event_id,
- expect_code=400,
- content={"redacts": "foo"},
+ content={"redacts": event_id} if include_content else {},
)
-
- # Redact it for real.
- self._redact_event(self.mod_access_token, room_id, event_id)
+ redaction_event_id = result["event_id"]
# Sync the room, to get the id of the create event
timeline = self._sync_room_timeline(self.mod_access_token, room_id)
redact_event = timeline[-1]
self.assertEqual(redact_event["type"], EventTypes.Redaction)
- # The redacts key should be in the content.
+ # The redacts key should be in the content and the redacts keys.
self.assertEquals(redact_event["content"]["redacts"], event_id)
-
- # It should also be copied as the top-level redacts field for backwards
- # compatibility.
self.assertEquals(redact_event["redacts"], event_id)
# But it isn't actually part of the event.
def get_event(txn: LoggingTransaction) -> JsonDict:
return db_to_json(
- main_datastore._fetch_event_rows(txn, [event_id])[event_id].json
+ main_datastore._fetch_event_rows(txn, [redaction_event_id])[
+ redaction_event_id
+ ].json
)
main_datastore = self.hs.get_datastores().main
event_json = self.get_success(
main_datastore.db_pool.runInteraction("get_event", get_event)
)
- self.assertNotIn("redacts", event_json)
+ self.assertEquals(event_json["type"], EventTypes.Redaction)
+ if expect_content:
+ self.assertNotIn("redacts", event_json)
+ self.assertEquals(event_json["content"]["redacts"], event_id)
+ else:
+ self.assertEquals(event_json["redacts"], event_id)
+ self.assertNotIn("redacts", event_json["content"])
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index d013e75d55..4f6347be15 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -711,7 +711,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(30, channel.resource_usage.db_txn_count)
+ self.assertEqual(32, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -724,7 +724,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(32, channel.resource_usage.db_txn_count)
+ self.assertEqual(34, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index ad454f6dd8..383da83dfb 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -448,3 +448,55 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
self.get_success(self.store._on_shutdown())
self.assertEqual(self.store._live_read_write_lock_tokens, {})
+
+ def test_acquire_multiple_locks(self) -> None:
+ """Tests that acquiring multiple locks at once works."""
+
+ # Take out multiple locks and ensure that we can't get those locks out
+ # again.
+ lock = self.get_success(
+ self.store.try_acquire_multi_read_write_lock(
+ [("name1", "key1"), ("name2", "key2")], write=True
+ )
+ )
+ self.assertIsNotNone(lock)
+
+ assert lock is not None
+ self.get_success(lock.__aenter__())
+
+ lock2 = self.get_success(
+ self.store.try_acquire_read_write_lock("name1", "key1", write=True)
+ )
+ self.assertIsNone(lock2)
+
+ lock3 = self.get_success(
+ self.store.try_acquire_read_write_lock("name2", "key2", write=False)
+ )
+ self.assertIsNone(lock3)
+
+ # Overlapping locks attempts will fail, and won't lock any locks.
+ lock4 = self.get_success(
+ self.store.try_acquire_multi_read_write_lock(
+ [("name1", "key1"), ("name3", "key3")], write=True
+ )
+ )
+ self.assertIsNone(lock4)
+
+ lock5 = self.get_success(
+ self.store.try_acquire_read_write_lock("name3", "key3", write=True)
+ )
+ self.assertIsNotNone(lock5)
+ assert lock5 is not None
+ self.get_success(lock5.__aenter__())
+ self.get_success(lock5.__aexit__(None, None, None))
+
+ # Once we release the lock we can take out the locks again.
+ self.get_success(lock.__aexit__(None, None, None))
+
+ lock6 = self.get_success(
+ self.store.try_acquire_read_write_lock("name1", "key1", write=True)
+ )
+ self.assertIsNotNone(lock6)
+ assert lock6 is not None
+ self.get_success(lock6.__aenter__())
+ self.get_success(lock6.__aexit__(None, None, None))
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index 2fab84a529..ef06b50dbb 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -17,7 +17,6 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.databases.main.transactions import DestinationRetryTimings
from synapse.util import Clock
-from synapse.util.retryutils import MAX_RETRY_INTERVAL
from tests.unittest import HomeserverTestCase
@@ -57,8 +56,14 @@ class TransactionStoreTestCase(HomeserverTestCase):
self.get_success(d)
def test_large_destination_retry(self) -> None:
+ max_retry_interval_ms = (
+ self.hs.config.federation.destination_max_retry_interval_ms
+ )
d = self.store.set_destination_retry_timings(
- "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
+ "example.com",
+ max_retry_interval_ms,
+ max_retry_interval_ms,
+ max_retry_interval_ms,
)
self.get_success(d)
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 5f8f4e76b5..1277e1a865 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -11,12 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.retryutils import (
- MIN_RETRY_INTERVAL,
- RETRY_MULTIPLIER,
- NotRetryingDestination,
- get_retry_limiter,
-)
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from tests.unittest import HomeserverTestCase
@@ -42,6 +37,11 @@ class RetryLimiterTestCase(HomeserverTestCase):
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
+ min_retry_interval_ms = (
+ self.hs.config.federation.destination_min_retry_interval_ms
+ )
+ retry_multiplier = self.hs.config.federation.destination_retry_multiplier
+
self.pump(1)
try:
with limiter:
@@ -57,7 +57,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
assert new_timings is not None
self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, failure_ts)
- self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
+ self.assertEqual(new_timings.retry_interval, min_retry_interval_ms)
# now if we try again we should get a failure
self.get_failure(
@@ -68,7 +68,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
# advance the clock and try again
#
- self.pump(MIN_RETRY_INTERVAL)
+ self.pump(min_retry_interval_ms)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
@@ -87,16 +87,16 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, retry_ts)
self.assertGreaterEqual(
- new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
+ new_timings.retry_interval, min_retry_interval_ms * retry_multiplier * 0.5
)
self.assertLessEqual(
- new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
+ new_timings.retry_interval, min_retry_interval_ms * retry_multiplier * 2.0
)
#
# one more go, with success
#
- self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
+ self.reactor.advance(min_retry_interval_ms * retry_multiplier * 2.0)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
|