diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 75fb5fae6b..366b6fd5f0 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -76,7 +76,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
) -> List[JsonDict]:
# Ensure the access token is passed as a header.
- if not headers or not headers.get("Authorization"):
+ if not headers or not headers.get(b"Authorization"):
raise RuntimeError("Access token not provided")
# ... and not as a query param
if b"access_token" in args:
@@ -84,7 +84,9 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
"Access token should not be passed as a query param."
)
- self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+ self.assertEqual(
+ headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+ )
self.request_url = url
if url == URL_USER:
return SUCCESS_RESULT_USER
@@ -152,11 +154,13 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
             # Ensure the access token is passed as both a query param and in the headers.
if not args.get(b"access_token"):
raise RuntimeError("Access token should be provided in query params.")
- if not headers or not headers.get("Authorization"):
+ if not headers or not headers.get(b"Authorization"):
raise RuntimeError("Access token should be provided in auth headers.")
self.assertEqual(args.get(b"access_token"), TOKEN)
- self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+ self.assertEqual(
+ headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+ )
self.request_url = url
if url == URL_USER:
return SUCCESS_RESULT_USER
@@ -208,10 +212,12 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
) -> JsonDict:
# Ensure the access token is passed as both a header and query arg.
- if not headers.get("Authorization"):
+ if not headers.get(b"Authorization"):
raise RuntimeError("Access token not provided")
- self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+ self.assertEqual(
+ headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+ )
return RESPONSE
# We assign to a method, which mypy doesn't like.
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index f93ba5d4cf..c5700771b0 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -13,7 +13,7 @@
# limitations under the License.
import time
from typing import Any, Dict, List, Optional, cast
-from unittest.mock import AsyncMock, Mock
+from unittest.mock import Mock
import attr
import canonicaljson
@@ -189,23 +189,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key("1")
- r = self.hs.get_datastores().main.store_server_keys_json(
+ r = self.hs.get_datastores().main.store_server_keys_response(
"server9",
- get_key_id(key1),
from_server="test",
- ts_now_ms=int(time.time() * 1000),
- ts_expires_ms=1000,
+ ts_added_ms=int(time.time() * 1000),
+ verify_keys={
+ get_key_id(key1): FetchKeyResult(
+ verify_key=get_verify_key(key1), valid_until_ts=1000
+ )
+ },
# The entire response gets signed & stored, just include the bits we
# care about.
- key_json_bytes=canonicaljson.encode_canonical_json(
- {
- "verify_keys": {
- get_key_id(key1): {
- "key": encode_verify_key_base64(get_verify_key(key1))
- }
+ response_json={
+ "verify_keys": {
+ get_key_id(key1): {
+ "key": encode_verify_key_base64(get_verify_key(key1))
}
}
- ),
+ },
)
self.get_success(r)
@@ -285,34 +286,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
self.get_success(d)
- def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
- """Tests that we correctly handle key requests for keys we've stored
- with a null `ts_valid_until_ms`
- """
- mock_fetcher = Mock()
- mock_fetcher.get_keys = AsyncMock(return_value={})
-
- key1 = signedjson.key.generate_signing_key("1")
- r = self.hs.get_datastores().main.store_server_signature_keys(
- "server9",
- int(time.time() * 1000),
- # None is not a valid value in FetchKeyResult, but we're abusing this
- # API to insert null values into the database. The nulls get converted
- # to 0 when fetched in KeyStore.get_server_signature_keys.
- {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]
- )
- self.get_success(r)
-
- json1: JsonDict = {}
- signedjson.sign.sign_json(json1, "server9", key1)
-
- # should succeed on a signed object with a 0 minimum_valid_until_ms
- d = self.hs.get_datastores().main.get_server_signature_keys(
- [("server9", get_key_id(key1))]
- )
- result = self.get_success(d)
- self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
-
def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key("1")
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 46d022092e..a7e6cdd66a 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -422,6 +422,18 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
+ self.exclusive_as_user_2_device_id = "exclusive_as_device_2"
+ self.exclusive_as_user_2 = self.register_user("exclusive_as_user_2", "password")
+ self.exclusive_as_user_2_token = self.login(
+ "exclusive_as_user_2", "password", self.exclusive_as_user_2_device_id
+ )
+
+ self.exclusive_as_user_3_device_id = "exclusive_as_device_3"
+ self.exclusive_as_user_3 = self.register_user("exclusive_as_user_3", "password")
+ self.exclusive_as_user_3_token = self.login(
+ "exclusive_as_user_3", "password", self.exclusive_as_user_3_device_id
+ )
+
def _notify_interested_services(self) -> None:
# This is normally set in `notify_interested_services` but we need to call the
# internal async version so the reactor gets pushed to completion.
@@ -849,6 +861,119 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
for count in service_id_to_message_count.values():
self.assertEqual(count, number_of_messages)
+ @unittest.override_config(
+ {"experimental_features": {"msc2409_to_device_messages_enabled": True}}
+ )
+ def test_application_services_receive_local_to_device_for_many_users(self) -> None:
+ """
+        Test that when a user sends to-device messages to many users in an
+        application service's user namespace, the application service will
+        receive all of them.
+ """
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": "@exclusive_as_user:.+",
+ "exclusive": True,
+ },
+ {
+ "regex": "@exclusive_as_user_2:.+",
+ "exclusive": True,
+ },
+ {
+ "regex": "@exclusive_as_user_3:.+",
+ "exclusive": True,
+ },
+ ],
+ },
+ )
+
+ # Have local_user send a to-device message to exclusive_as_users
+ message_content = {"some_key": "some really interesting value"}
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/3",
+ content={
+ "messages": {
+ self.exclusive_as_user: {
+ self.exclusive_as_user_device_id: message_content
+ },
+ self.exclusive_as_user_2: {
+ self.exclusive_as_user_2_device_id: message_content
+ },
+ self.exclusive_as_user_3: {
+ self.exclusive_as_user_3_device_id: message_content
+ },
+ }
+ },
+ access_token=self.local_user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+        # Have each exclusive_as_user send a to-device message back to local_user
+ for user_token in [
+ self.exclusive_as_user_token,
+ self.exclusive_as_user_2_token,
+ self.exclusive_as_user_3_token,
+ ]:
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/4",
+ content={
+ "messages": {
+ self.local_user: {self.local_user_device_id: message_content}
+ }
+ },
+ access_token=user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+        # Check if our application service - that is interested in the exclusive_as_user
+        # accounts - received the to-device messages as part of an AS transaction.
+        # Only the local_user -> exclusive_as_user* to-device messages should have been forwarded to the AS.
+ #
+ # The uninterested application service should not have been notified at all.
+ self.send_mock.assert_called_once()
+ (
+ service,
+ _events,
+ _ephemeral,
+ to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+        # Assert the messages were routed to the interested application service
+ self.assertEqual(service, interested_appservice)
+
+ # Assert expected number of messages
+ self.assertEqual(len(to_device_messages), 3)
+
+ for device_msg in to_device_messages:
+ self.assertEqual(device_msg["type"], "m.room_key_request")
+ self.assertEqual(device_msg["sender"], self.local_user)
+ self.assertEqual(device_msg["content"], message_content)
+
+ self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user)
+ self.assertEqual(
+ to_device_messages[0]["to_device_id"],
+ self.exclusive_as_user_device_id,
+ )
+
+ self.assertEqual(to_device_messages[1]["to_user_id"], self.exclusive_as_user_2)
+ self.assertEqual(
+ to_device_messages[1]["to_device_id"],
+ self.exclusive_as_user_2_device_id,
+ )
+
+ self.assertEqual(to_device_messages[2]["to_user_id"], self.exclusive_as_user_3)
+ self.assertEqual(
+ to_device_messages[2]["to_device_id"],
+ self.exclusive_as_user_3_device_id,
+ )
+
def _register_application_service(
self,
namespaces: Optional[Dict[str, Iterable[Dict]]] = None,
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 8582b1cd1e..13e2cd153a 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -197,6 +197,23 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
+ @override_config({"cas_config": {"enable_registration": False}})
+ def test_map_cas_user_does_not_register_new_user(self) -> None:
+ """Ensures new users are not registered if the enabled registration flag is disabled."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
+
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+        # check that the auth handler was not called, as expected
+ auth_handler.complete_sso_login.assert_not_called()
+
def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest"""
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 55a4f95ef3..79d327499b 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -30,6 +30,7 @@ from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict, create_requester
from synapse.util import Clock
+from synapse.util.task_scheduler import TaskScheduler
from tests import unittest
from tests.unittest import override_config
@@ -49,6 +50,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
assert isinstance(handler, DeviceHandler)
self.handler = handler
self.store = hs.get_datastores().main
+ self.device_message_handler = hs.get_device_message_handler()
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -211,6 +213,51 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(res)
+ def test_delete_device_and_big_device_inbox(self) -> None:
+ """Check that deleting a big device inbox is staged and batched asynchronously."""
+ DEVICE_ID = "abc"
+ sender = "@sender:" + self.hs.hostname
+ receiver = "@receiver:" + self.hs.hostname
+ self._record_user(sender, DEVICE_ID, DEVICE_ID)
+ self._record_user(receiver, DEVICE_ID, DEVICE_ID)
+
+ # queue a bunch of messages in the inbox
+ requester = create_requester(sender, device_id=DEVICE_ID)
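+        # Exceed the delete batch limit by 10 so that a second, scheduled deletion pass is needed.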
+ for i in range(DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
+ self.get_success(
+ self.device_message_handler.send_device_message(
+ requester, "message_type", {receiver: {"*": {"val": i}}}
+ )
+ )
+
+ # delete the device
+ self.get_success(self.handler.delete_devices(receiver, [DEVICE_ID]))
+
+ # messages should be deleted up to DEVICE_MSGS_DELETE_BATCH_LIMIT straight away
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="device_inbox",
+ keyvalues={"user_id": receiver},
+ retcols=("user_id", "device_id", "stream_id"),
+ desc="get_device_id_from_device_inbox",
+ )
+ )
+ self.assertEqual(10, len(res))
+
+ # wait for the task scheduler to do a second delete pass
+ self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)
+
+ # remaining messages should now be deleted
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="device_inbox",
+ keyvalues={"user_id": receiver},
+ retcols=("user_id", "device_id", "stream_id"),
+ desc="get_device_id_from_device_inbox",
+ )
+ )
+ self.assertEqual(0, len(res))
+
def test_update_device(self) -> None:
self._record_users()
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 21d63ab1f2..4fc0742413 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -262,7 +262,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
if (ev.type, ev.state_key)
in {("m.room.create", ""), ("m.room.member", remote_server_user_id)}
]
- for _ in range(0, 8):
+ for _ in range(8):
event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 88a16193a3..638787b029 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -21,11 +21,12 @@ from signedjson.key import generate_signing_key
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership, PresenceState
-from synapse.api.presence import UserPresenceState
+from synapse.api.presence import UserDevicePresenceState, UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.federation.sender import FederationSender
from synapse.handlers.presence import (
+ BUSY_ONLINE_TIMEOUT,
EXTERNAL_PROCESS_EXPIRY,
FEDERATION_PING_INTERVAL,
FEDERATION_TIMEOUT,
@@ -352,6 +353,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_idle_timer(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -362,8 +364,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
assert new_state is not None
@@ -376,6 +391,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
presence state into unavailable.
"""
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -386,8 +402,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
assert new_state is not None
@@ -396,6 +425,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_sync_timeout(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -406,8 +436,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
assert new_state is not None
@@ -416,6 +459,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_sync_online(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -426,9 +470,20 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
new_state = handle_timeout(
- state, is_mine=True, syncing_user_ids={user_id}, now=now
+ state,
+ is_mine=True,
+ syncing_device_ids={(user_id, device_id)},
+ user_devices={device_id: device_state},
+ now=now,
)
self.assertIsNotNone(new_state)
@@ -438,6 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_federation_ping(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -449,14 +505,28 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
self.assertEqual(state, new_state)
def test_no_timeout(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
now = 5000000
state = UserPresenceState.default(user_id)
@@ -466,8 +536,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now,
last_federation_update_ts=now,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNone(new_state)
@@ -485,8 +568,9 @@ class PresenceTimeoutTestCase(unittest.TestCase):
status_msg=status_msg,
)
+ # Note that this is a remote user so we do not have their device information.
new_state = handle_timeout(
- state, is_mine=False, syncing_user_ids=set(), now=now
+ state, is_mine=False, syncing_device_ids=set(), user_devices={}, now=now
)
self.assertIsNotNone(new_state)
@@ -496,6 +580,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_last_active(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -507,8 +592,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
self.assertEqual(state, new_state)
@@ -579,7 +677,7 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
[
(PresenceState.BUSY, PresenceState.BUSY),
(PresenceState.ONLINE, PresenceState.ONLINE),
- (PresenceState.UNAVAILABLE, PresenceState.UNAVAILABLE),
+ (PresenceState.UNAVAILABLE, PresenceState.ONLINE),
# Offline syncs don't update the state.
(PresenceState.OFFLINE, PresenceState.ONLINE),
]
@@ -800,6 +898,486 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# we should now be online
self.assertEqual(state.state, PresenceState.ONLINE)
+ @parameterized.expand(
+        # A list of tuples of 5 strings and a boolean:
+ #
+ # * The presence state of device 1.
+ # * The presence state of device 2.
+ # * The expected user presence state after both devices have synced.
+ # * The expected user presence state after device 1 has idled.
+ # * The expected user presence state after device 2 has idled.
+        # * True to use workers, False to use a monolith.
+ [
+ (*cases, workers)
+ for workers in (False, True)
+ for cases in [
+ # If both devices have the same state, online should eventually idle.
+ # Otherwise, the state doesn't change.
+ (
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ ),
+                # If the second device has a "lower" state it should fall back to it,
+ # except for "busy" which overrides.
+ (
+ PresenceState.BUSY,
+ PresenceState.ONLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.BUSY,
+ PresenceState.UNAVAILABLE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.BUSY,
+ PresenceState.OFFLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.OFFLINE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ # If the second device has a "higher" state it should override.
+ (
+ PresenceState.ONLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ ]
+ ],
+ name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}",
+ )
+ @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
+ def test_set_presence_from_syncing_multi_device(
+ self,
+ dev_1_state: str,
+ dev_2_state: str,
+ expected_state_1: str,
+ expected_state_2: str,
+ expected_state_3: str,
+ test_with_workers: bool,
+ ) -> None:
+ """
+ Test the behaviour of multiple devices syncing at the same time.
+
+ Roughly the user's presence state should be set to the "highest" priority
+ of all the devices. When a device then goes offline its state should be
+ discarded and the next highest should win.
+
+        Note that these tests use the idle timer (and don't close the syncs). It
+        is unlikely that a *single* sync would last this long, but it is close
+        enough to continually syncing with that current state.
+ """
+ user_id = f"@test:{self.hs.config.server.server_name}"
+
+ # By default, we call /sync against the main process.
+ worker_presence_handler = self.presence_handler
+ if test_with_workers:
+ # Create a worker and use it to handle /sync traffic instead.
+ # This is used to test that presence changes get replicated from workers
+ # to the main process correctly.
+ worker_to_sync_against = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+ )
+ worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
+ # 1. Sync with the first device.
+ self.get_success(
+ worker_presence_handler.user_syncing(
+ user_id,
+ "dev-1",
+ affect_presence=dev_1_state != PresenceState.OFFLINE,
+ presence_state=dev_1_state,
+ ),
+ by=0.01,
+ )
+
+ # 2. Wait half the idle timer.
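+        # (Staggering the syncs means device 1 will reach the idle threshold before device 2.)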
+ self.reactor.advance(IDLE_TIMER / 1000 / 2)
+ self.reactor.pump([0.1])
+
+ # 3. Sync with the second device.
+ self.get_success(
+ worker_presence_handler.user_syncing(
+ user_id,
+ "dev-2",
+ affect_presence=dev_2_state != PresenceState.OFFLINE,
+ presence_state=dev_2_state,
+ ),
+ by=0.01,
+ )
+
+ # 4. Assert the expected presence state.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_1)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_1)
+
+ # When testing with workers, make another random sync (with any *different*
+ # user) to keep the process information from expiring.
+ #
+ # This is due to EXTERNAL_PROCESS_EXPIRY being equivalent to IDLE_TIMER.
+ if test_with_workers:
+ with self.get_success(
+ worker_presence_handler.user_syncing(
+ f"@other-user:{self.hs.config.server.server_name}",
+ "dev-3",
+ affect_presence=True,
+ presence_state=PresenceState.ONLINE,
+ ),
+ by=0.01,
+ ):
+ pass
+
+ # 5. Advance such that the first device should be discarded (the idle timer),
+        # then pump so the _handle_timeouts function gets called.
+ self.reactor.advance(IDLE_TIMER / 1000 / 2)
+ self.reactor.pump([0.01])
+
+ # 6. Assert the expected presence state.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_2)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_2)
+
+ # 7. Advance such that the second device should be discarded (half the idle timer),
+        # then pump so the _handle_timeouts function gets called.
+ self.reactor.advance(IDLE_TIMER / 1000 / 2)
+ self.reactor.pump([0.1])
+
+ # 8. The devices are still "syncing" (the sync context managers were never
+ # closed), so might idle.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_3)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_3)
+
+ @parameterized.expand(
+        # A list of tuples of 4 strings and a boolean:
+ #
+ # * The presence state of device 1.
+ # * The presence state of device 2.
+ # * The expected user presence state after both devices have synced.
+ # * The expected user presence state after device 1 has stopped syncing.
+        # * True to use workers, False to use a monolith.
+ [
+ (*cases, workers)
+ for workers in (False, True)
+ for cases in [
+ # If both devices have the same state, nothing exciting should happen.
+ (
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ ),
+                # If the second device has a "lower" state it should fall back to it,
+ # except for "busy" which overrides.
+ (
+ PresenceState.BUSY,
+ PresenceState.ONLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.BUSY,
+ PresenceState.UNAVAILABLE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.BUSY,
+ PresenceState.OFFLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.OFFLINE,
+ PresenceState.ONLINE,
+ PresenceState.OFFLINE,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ ),
+ # If the second device has a "higher" state it should override.
+ (
+ PresenceState.ONLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ ]
+ ],
+ name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}",
+ )
+ @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
+ def test_set_presence_from_non_syncing_multi_device(
+ self,
+ dev_1_state: str,
+ dev_2_state: str,
+ expected_state_1: str,
+ expected_state_2: str,
+ test_with_workers: bool,
+ ) -> None:
+ """
+ Test the behaviour of multiple devices syncing at the same time.
+
+ Roughly the user's presence state should be set to the "highest" priority
+ of all the devices. When a device then goes offline its state should be
+ discarded and the next highest should win.
+
+        Unlike the test above, these tests close the syncs and wait for the sync
+        timeout, rather than relying on the idle timer, so they exercise the
+        behaviour when devices stop syncing.
+ """
+ user_id = f"@test:{self.hs.config.server.server_name}"
+
+ # By default, we call /sync against the main process.
+ worker_presence_handler = self.presence_handler
+ if test_with_workers:
+ # Create a worker and use it to handle /sync traffic instead.
+ # This is used to test that presence changes get replicated from workers
+ # to the main process correctly.
+ worker_to_sync_against = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+ )
+ worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
+ # 1. Sync with the first device.
+ sync_1 = self.get_success(
+ worker_presence_handler.user_syncing(
+ user_id,
+ "dev-1",
+ affect_presence=dev_1_state != PresenceState.OFFLINE,
+ presence_state=dev_1_state,
+ ),
+ by=0.1,
+ )
+
+ # 2. Sync with the second device.
+ sync_2 = self.get_success(
+ worker_presence_handler.user_syncing(
+ user_id,
+ "dev-2",
+ affect_presence=dev_2_state != PresenceState.OFFLINE,
+ presence_state=dev_2_state,
+ ),
+ by=0.1,
+ )
+
+ # 3. Assert the expected presence state.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_1)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_1)
+
+ # 4. Disconnect the first device.
+ with sync_1:
+ pass
+
+ # 5. Advance such that the first device should be discarded (the sync timeout),
+        # then pump so the _handle_timeouts function gets called.
+ self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000)
+ self.reactor.pump([5])
+
+ # 6. Assert the expected presence state.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_2)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_2)
+
+ # 7. Disconnect the second device.
+ with sync_2:
+ pass
+
+ # 8. Advance such that the second device should be discarded (the sync timeout),
+        # then pump so the _handle_timeouts function gets called.
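+        # A device that was busy is only discarded after BUSY_ONLINE_TIMEOUT, hence the different wait below.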
+ if dev_1_state == PresenceState.BUSY or dev_2_state == PresenceState.BUSY:
+ timeout = BUSY_ONLINE_TIMEOUT
+ else:
+ timeout = SYNC_ONLINE_TIMEOUT
+ self.reactor.advance(timeout / 1000)
+ self.reactor.pump([5])
+
+ # 9. There are no more devices, should be offline.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+
def test_set_presence_from_syncing_keeps_status(self) -> None:
"""Test that presence set by syncing retains status message"""
status_msg = "I'm here!"
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 0d17f2fe5b..9f63fa6fa8 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -15,7 +15,7 @@ import base64
import logging
import os
from typing import Generator, List, Optional, cast
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, call, patch
import treq
from netaddr import IPSet
@@ -651,9 +651,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# .well-known request fails.
self.reactor.pump((0.4,))
- # now there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv1"
+ # now there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv1"), call(b"_matrix._tcp.testserv1")]
)
# we should fall back to a direct connection
@@ -737,9 +737,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# .well-known request fails.
self.reactor.pump((0.4,))
- # now there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ # now there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
# we should fall back to a direct connection
@@ -788,9 +788,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
content=b'{ "m.server": "target-server" }',
)
- # there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.target-server"
+ # there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.target-server"),
+ call(b"_matrix._tcp.target-server"),
+ ]
)
# now we should get a connection to the target server
@@ -878,9 +881,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
- # there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.target-server"
+ # there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.target-server"),
+ call(b"_matrix._tcp.target-server"),
+ ]
)
# now we should get a connection to the target server
@@ -942,9 +948,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
client_factory, expected_sni=b"testserv", content=b"NOT JSON"
)
- # now there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ # now there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
# we should fall back to a direct connection
@@ -1016,14 +1022,14 @@ class MatrixFederationAgentTests(unittest.TestCase):
# there should be no requests
self.assertEqual(len(http_proto.requests), 0)
- # and there should be a SRV lookup instead
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ # and there should be two SRV lookups instead
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
def test_get_hostname_srv(self) -> None:
"""
- Test the behaviour when there is a single SRV record
+ Test the behaviour when there is a single SRV record for _matrix-fed.
"""
self.agent = self._make_agent()
@@ -1039,7 +1045,51 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the request for a .well-known will have failed with a DNS lookup error.
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ b"_matrix-fed._tcp.testserv"
+ )
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_hostname_srv_legacy(self) -> None:
+ """
+ Test the behaviour when there is a single SRV record for _matrix.
+ """
+ self.agent = self._make_agent()
+
+ # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+ self.mock_resolver.resolve_service.side_effect = [
+ [],
+ [Server(host=b"srvtarget", port=8443)],
+ ]
+ self.reactor.lookups["srvtarget"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # the request for a .well-known will have failed with a DNS lookup error.
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
# Make sure treq is trying to connect
@@ -1065,7 +1115,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_get_well_known_srv(self) -> None:
"""Test the behaviour when the .well-known redirects to a place where there
- is a SRV.
+ is a _matrix-fed SRV record.
"""
self.agent = self._make_agent()
@@ -1096,7 +1146,72 @@ class MatrixFederationAgentTests(unittest.TestCase):
# there should be a SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.target-server"
+ b"_matrix-fed._tcp.target-server"
+ )
+
+ # now we should get a connection to the target of the SRV record
+ self.assertEqual(len(clients), 2)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[1]
+ self.assertEqual(host, "5.6.7.8")
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, expected_sni=b"target-server"
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [b"target-server"]
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_well_known_srv_legacy(self) -> None:
+ """Test the behaviour when the .well-known redirects to a place where there
+ is a _matrix SRV record.
+ """
+ self.agent = self._make_agent()
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["srvtarget"] = "5.6.7.8"
+
+ test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+ self.mock_resolver.resolve_service.side_effect = [
+ [],
+ [Server(host=b"srvtarget", port=8443)],
+ ]
+
+ self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ content=b'{ "m.server": "target-server" }',
+ )
+
+ # there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.target-server"),
+ call(b"_matrix._tcp.target-server"),
+ ]
)
# now we should get a connection to the target of the SRV record
@@ -1158,8 +1273,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.4,))
# now there should have been a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.xn--bcher-kva.com"
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.xn--bcher-kva.com"),
+ call(b"_matrix._tcp.xn--bcher-kva.com"),
+ ]
)
# We should fall back to port 8448
@@ -1188,7 +1306,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.successResultOf(test_d)
def test_idna_srv_target(self) -> None:
- """test the behaviour when the target of a SRV record has idna chars"""
+ """test the behaviour when the target of a _matrix-fed SRV record has idna chars"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.return_value = [
@@ -1204,7 +1322,57 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.xn--bcher-kva.com"
+ b"_matrix-fed._tcp.xn--bcher-kva.com"
+ )
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, expected_sni=b"xn--bcher-kva.com"
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"]
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_idna_srv_target_legacy(self) -> None:
+ """test the behaviour when the target of a _matrix SRV record has idna chars"""
+ self.agent = self._make_agent()
+
+ # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+ self.mock_resolver.resolve_service.side_effect = [
+ [],
+ [Server(host=b"xn--trget-3qa.com", port=8443)],
+ ] # târget.com
+ self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(
+ b"matrix-federation://xn--bcher-kva.com/foo/bar"
+ )
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.xn--bcher-kva.com"),
+ call(b"_matrix._tcp.xn--bcher-kva.com"),
+ ]
)
# Make sure treq is trying to connect
@@ -1394,7 +1562,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertIsNone(r.delegated_server)
def test_srv_fallbacks(self) -> None:
- """Test that other SRV results are tried if the first one fails."""
+ """Test that other SRV results are tried if the first one fails for _matrix-fed SRV."""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.return_value = [
@@ -1409,7 +1577,67 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ b"_matrix-fed._tcp.testserv"
+ )
+
+ # We should see an attempt to connect to the first server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # Fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There's a 300ms delay in HostnameEndpoint
+ self.reactor.pump((0.4,))
+
+ # Hasn't failed yet
+ self.assertNoResult(test_d)
+
+        # We should now see an attempt to connect to the second server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8444)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_srv_fallbacks_legacy(self) -> None:
+ """Test that other SRV results are tried if the first one fails for _matrix SRV."""
+ self.agent = self._make_agent()
+
+ # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+ self.mock_resolver.resolve_service.side_effect = [
+ [],
+ [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ],
+ ]
+ self.reactor.lookups["target.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
# We should see an attempt to connect to the first server
@@ -1449,6 +1677,43 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
+ def test_srv_no_fallback_to_legacy(self) -> None:
+ """Test that _matrix SRV results are not tried if the _matrix-fed one fails."""
+ self.agent = self._make_agent()
+
+        # Return an entry for _matrix-fed whose connection will fail, and nothing for _matrix.
+ self.mock_resolver.resolve_service.side_effect = [
+ [Server(host=b"target.com", port=8443)],
+ [],
+ ]
+ self.reactor.lookups["target.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+        # Only the _matrix-fed lookup is made; _matrix is ignored.
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix-fed._tcp.testserv"
+ )
+
+ # We should see an attempt to connect to the first server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # Fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There's a 300ms delay in HostnameEndpoint
+ self.reactor.pump((0.4,))
+
+        # The request fails; no fallback to the _matrix lookup is attempted.
+ self.assertFailure(test_d, Exception)
+
class TestCachePeriodFromHeaders(unittest.TestCase):
def test_cache_control(self) -> None:
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index 5191e31a8a..45eac100bf 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -78,11 +78,11 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
logger = self.get_logger(handler)
# Send some debug messages
- for i in range(0, 3):
+ for i in range(3):
logger.debug("debug %s" % (i,))
# Send a bunch of useful messages
- for i in range(0, 7):
+ for i in range(7):
logger.info("info %s" % (i,))
# The last debug message pushes it past the maximum buffer
@@ -108,15 +108,15 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
logger = self.get_logger(handler)
# Send some debug messages
- for i in range(0, 3):
+ for i in range(3):
logger.debug("debug %s" % (i,))
# Send a bunch of useful messages
- for i in range(0, 10):
+ for i in range(10):
logger.warning("warn %s" % (i,))
# Send a bunch of info messages
- for i in range(0, 3):
+ for i in range(3):
logger.info("info %s" % (i,))
# The last debug message pushes it past the maximum buffer
@@ -144,7 +144,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
logger = self.get_logger(handler)
# Send a bunch of useful messages
- for i in range(0, 20):
+ for i in range(20):
logger.warning("warn %s" % (i,))
# Allow the reconnection
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 4b5c96aeae..73a430ddc6 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -13,10 +13,12 @@
# limitations under the License.
import email.message
import os
+from http import HTTPStatus
from typing import Any, Dict, List, Sequence, Tuple
import attr
import pkg_resources
+from parameterized import parameterized
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
@@ -25,9 +27,11 @@ import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
from synapse.push.emailpusher import EmailPusher
from synapse.rest.client import login, room
+from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
from synapse.server import HomeServer
from synapse.util import Clock
+from tests.server import FakeSite, make_request
from tests.unittest import HomeserverTestCase
@@ -175,6 +179,57 @@ class EmailPusherTests(HomeserverTestCase):
self._check_for_mail()
+ @parameterized.expand([(False,), (True,)])
+ def test_unsubscribe(self, use_post: bool) -> None:
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+ )
+ self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+ # The other user sends a single message.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+ # We should get emailed about that message
+ args, kwargs = self._check_for_mail()
+
+ # That email should contain an unsubscribe link in the body and header.
+ msg: bytes = args[5]
+
+ # Multipart: plain text, base 64 encoded; html, base 64 encoded
+ multipart_msg = email.message_from_bytes(msg)
+ txt = multipart_msg.get_payload()[0].get_payload(decode=True).decode()
+ html = multipart_msg.get_payload()[1].get_payload(decode=True).decode()
+ self.assertIn("/_synapse/client/unsubscribe", txt)
+ self.assertIn("/_synapse/client/unsubscribe", html)
+
+ # The unsubscribe headers should exist.
+ assert multipart_msg.get("List-Unsubscribe") is not None
+ self.assertIsNotNone(multipart_msg.get("List-Unsubscribe-Post"))
+
+ # Open the unsubscribe link.
+ unsubscribe_link = multipart_msg["List-Unsubscribe"].strip("<>")
+ unsubscribe_resource = UnsubscribeResource(self.hs)
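+        # Serve the resource directly via a FakeSite rather than going through the full REST stack.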
+ channel = make_request(
+ self.reactor,
+ FakeSite(unsubscribe_resource, self.reactor),
+ "POST" if use_post else "GET",
+ unsubscribe_link,
+ shorthand=False,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ # Ensure the pusher was removed.
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
+ )
+ self.assertEqual(pushers, [])
+
def test_invite_sends_email(self) -> None:
# Create a room and invite the user to it
room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
diff --git a/tests/replication/tcp/streams/test_to_device.py b/tests/replication/tcp/streams/test_to_device.py
index fb9eac668f..ab379e8cf1 100644
--- a/tests/replication/tcp/streams/test_to_device.py
+++ b/tests/replication/tcp/streams/test_to_device.py
@@ -49,7 +49,7 @@ class ToDeviceStreamTestCase(BaseStreamTestCase):
# add messages to the device inbox for user1 up until the
# limit defined for a stream update batch
- for i in range(0, _STREAM_UPDATE_TARGET_ROW_COUNT):
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT):
msg["content"] = {"device": {}}
messages = {user1: {"device": msg}}
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 4c7864c629..0e2824d1b5 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -510,7 +510,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
Args:
number_destinations: Number of destinations to be created
"""
- for i in range(0, number_destinations):
+ for i in range(number_destinations):
dest = f"sub{i}.example.com"
self._create_destination(dest, 50, 50, 50, 100)
@@ -690,7 +690,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
self._check_fields(channel_desc.json_body["rooms"])
# test that both lists have different directions
- for i in range(0, number_rooms):
+ for i in range(number_rooms):
self.assertEqual(
channel_asc.json_body["rooms"][i]["room_id"],
channel_desc.json_body["rooms"][number_rooms - 1 - i]["room_id"],
@@ -777,7 +777,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
Args:
number_rooms: Number of rooms to be created
"""
- for _ in range(0, number_rooms):
+ for _ in range(number_rooms):
room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
)
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index e9f495e206..cffbda9a7d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,6 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
+from synapse.storage._base import db_to_json
from synapse.types import JsonDict, UserID
from synapse.util import Clock
@@ -134,6 +135,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ # Check that the UI Auth information doesn't store the password in the database.
+ #
+ # Note that we don't have the UI Auth session ID, so just pull out the single
+ # row.
+ ui_auth_data = self.get_success(
+ self.store.db_pool.simple_select_one(
+ "ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
+ )
+ )
+ client_dict = db_to_json(ui_auth_data["clientdict"])
+ self.assertNotIn("new_password", client_dict)
+
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_email(self) -> None:
"""Test that we ratelimit /requestToken for the same email."""
@@ -562,7 +575,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
# create a bunch of users and add keys for them
users = []
- for i in range(0, 20):
+ for i in range(20):
user_id = self.register_user("missPiggy" + str(i), "test")
users.append((user_id,))
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a2a6589564..768d7ad4c2 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -176,10 +176,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_ratelimiting_per_address(self) -> None:
# Create different users so we're sure not to be bothered by the per-user
# ratelimiter.
- for i in range(0, 6):
+ for i in range(6):
self.register_user("kermit" + str(i), "monkey")
- for i in range(0, 6):
+ for i in range(6):
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
@@ -228,7 +228,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_ratelimiting_per_account(self) -> None:
self.register_user("kermit", "monkey")
- for i in range(0, 6):
+ for i in range(6):
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
@@ -277,7 +277,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
self.register_user("kermit", "monkey")
- for i in range(0, 6):
+ for i in range(6):
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index c33393dc28..ba4e017a0e 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -169,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting_guest(self) -> None:
- for i in range(0, 6):
+ for i in range(6):
url = self.url + b"?kind=guest"
channel = self.make_request(b"POST", url, b"{}")
@@ -187,7 +187,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None:
- for i in range(0, 6):
+ for i in range(6):
request_data = {
"username": "kermit" + str(i),
"password": "monkey",
@@ -1223,7 +1223,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
def test_GET_ratelimiting(self) -> None:
token = "1234"
- for i in range(0, 6):
+ for i in range(6):
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 650b4941ba..35f77052a7 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -382,7 +382,7 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
self.get_success(lock.__aenter__())
# Wait for ages with the lock, we should not be able to get the lock.
- for _ in range(0, 10):
+ for _ in range(10):
self.reactor.advance((_RENEWAL_INTERVAL_MS / 1000))
lock2 = self.get_success(
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 48ebfadaab..b55dd07f14 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Add a bunch of state so that it takes multiple iterations of the
# background update to process the room.
- for i in range(0, 150):
+ for i in range(150):
self.helper.send_state(
room_id, event_type="m.test", body={"index": i}, tok=self.token
)
@@ -718,12 +718,12 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Add a bunch of state so that it takes multiple iterations of the
# background update to process the room.
- for i in range(0, 150):
+ for i in range(150):
self.helper.send_state(
room_id1, event_type="m.test", body={"index": i}, tok=self.token
)
- for i in range(0, 150):
+ for i in range(150):
self.helper.send_state(
room_id2, event_type="m.test", body={"index": i}, tok=self.token
)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 7a4ecab2d5..d3e20f44b2 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -227,7 +227,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
(room_id, event_id),
)
- for i in range(0, 20):
+ for i in range(20):
self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i)
)
@@ -235,7 +235,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# this should get the last ten
r = self.get_success(self.store.get_prev_events_for_room(room_id))
self.assertEqual(10, len(r))
- for i in range(0, 10):
+ for i in range(10):
self.assertEqual("$event_%i:local" % (19 - i), r[i])
def test_get_rooms_with_many_extremities(self) -> None:
@@ -277,7 +277,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
(room_id, event_id),
)
- for i in range(0, 20):
+ for i in range(20):
self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i, room1)
)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
deleted file mode 100644
index 5d7c13e6d0..0000000000
--- a/tests/storage/test_keys.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import signedjson.key
-import signedjson.types
-import unpaddedbase64
-
-from synapse.storage.keys import FetchKeyResult
-
-import tests.unittest
-
-
-def decode_verify_key_base64(
- key_id: str, key_base64: str
-) -> signedjson.types.VerifyKey:
- key_bytes = unpaddedbase64.decode_base64(key_base64)
- return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
-
-
-KEY_1 = decode_verify_key_base64(
- "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
-)
-KEY_2 = decode_verify_key_base64(
- "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
-)
-
-
-class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
- def test_get_server_signature_keys(self) -> None:
- store = self.hs.get_datastores().main
-
- key_id_1 = "ed25519:key1"
- key_id_2 = "ed25519:KEY_ID_2"
- self.get_success(
- store.store_server_signature_keys(
- "from_server",
- 10,
- {
- ("server1", key_id_1): FetchKeyResult(KEY_1, 100),
- ("server1", key_id_2): FetchKeyResult(KEY_2, 200),
- },
- )
- )
-
- res = self.get_success(
- store.get_server_signature_keys(
- [
- ("server1", key_id_1),
- ("server1", key_id_2),
- ("server1", "ed25519:key3"),
- ]
- )
- )
-
- self.assertEqual(len(res.keys()), 3)
- res1 = res[("server1", key_id_1)]
- self.assertEqual(res1.verify_key, KEY_1)
- self.assertEqual(res1.verify_key.version, "key1")
- self.assertEqual(res1.valid_until_ts, 100)
-
- res2 = res[("server1", key_id_2)]
- self.assertEqual(res2.verify_key, KEY_2)
- # version comes from the ID it was stored with
- self.assertEqual(res2.verify_key.version, "KEY_ID_2")
- self.assertEqual(res2.valid_until_ts, 200)
-
- # non-existent result gives None
- self.assertIsNone(res[("server1", "ed25519:key3")])
-
- def test_cache(self) -> None:
- """Check that updates correctly invalidate the cache."""
-
- store = self.hs.get_datastores().main
-
- key_id_1 = "ed25519:key1"
- key_id_2 = "ed25519:key2"
-
- self.get_success(
- store.store_server_signature_keys(
- "from_server",
- 0,
- {
- ("srv1", key_id_1): FetchKeyResult(KEY_1, 100),
- ("srv1", key_id_2): FetchKeyResult(KEY_2, 200),
- },
- )
- )
-
- res = self.get_success(
- store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- )
- self.assertEqual(len(res.keys()), 2)
-
- res1 = res[("srv1", key_id_1)]
- self.assertEqual(res1.verify_key, KEY_1)
- self.assertEqual(res1.valid_until_ts, 100)
-
- res2 = res[("srv1", key_id_2)]
- self.assertEqual(res2.verify_key, KEY_2)
- self.assertEqual(res2.valid_until_ts, 200)
-
- # we should be able to look up the same thing again without a db hit
- res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)]))
- self.assertEqual(len(res.keys()), 1)
- self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
-
- new_key_2 = signedjson.key.get_verify_key(
- signedjson.key.generate_signing_key("key2")
- )
- d = store.store_server_signature_keys(
- "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
- )
- self.get_success(d)
-
- res = self.get_success(
- store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- )
- self.assertEqual(len(res.keys()), 2)
-
- res1 = res[("srv1", key_id_1)]
- self.assertEqual(res1.verify_key, KEY_1)
- self.assertEqual(res1.valid_until_ts, 100)
-
- res2 = res[("srv1", key_id_2)]
- self.assertEqual(res2.verify_key, new_key_2)
- self.assertEqual(res2.valid_until_ts, 300)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index fe5bb77913..95f99f4130 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -82,7 +82,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.db_pool.runInteraction("", f))
- for i in range(0, 70):
+ for i in range(70):
self.get_success(
self.store.db_pool.simple_insert(
"profiles",
@@ -115,7 +115,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
)
expected_values = []
- for i in range(0, 70):
+ for i in range(70):
expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
res = self.get_success(
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index 15ea4770bd..22f074982f 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -38,5 +38,5 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
db_pool = self.hs.get_datastores().databases[0]
# force txn limit to roll over at least once
- for _ in range(0, 1001):
+ for _ in range(1001):
self.get_success_or_raise(db_pool.runInteraction("test_select", do_select))
diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py
index bab802f56e..d4637d9d1e 100644
--- a/tests/storage/test_user_filters.py
+++ b/tests/storage/test_user_filters.py
@@ -45,7 +45,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.db_pool.runInteraction("", f))
- for i in range(0, 70):
+ for i in range(70):
self.get_success(
self.store.db_pool.simple_insert(
"user_filters",
@@ -82,7 +82,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
)
expected_values = []
- for i in range(0, 70):
+ for i in range(70):
expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
res = self.get_success(
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index a46c29ddf4..434902c3f0 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -51,12 +51,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# before we do that, we persist some other events to act as state.
self._inject_visibility("@admin:hs", "joined")
- for i in range(0, 10):
+ for i in range(10):
self._inject_room_member("@resident%i:hs" % i)
events_to_filter = []
- for i in range(0, 10):
+ for i in range(10):
user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
evt = self._inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
@@ -74,7 +74,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
)
# the result should be 5 redacted events, and 5 unredacted events.
- for i in range(0, 5):
+ for i in range(5):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertNotIn("a", filtered[i].content)
@@ -177,7 +177,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
)
)
- for i in range(0, len(events_to_filter)):
+ for i in range(len(events_to_filter)):
self.assertEqual(
events_to_filter[i].event_id,
filtered[i].event_id,
diff --git a/tests/unittest.py b/tests/unittest.py
index 5d3640d8ac..dbaff361b4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -70,6 +70,7 @@ from synapse.logging.context import (
)
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
+from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
@@ -858,23 +859,22 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
self.get_success(
- hs.get_datastores().main.store_server_keys_json(
+ hs.get_datastores().main.store_server_keys_response(
self.OTHER_SERVER_NAME,
- verify_key_id,
from_server=self.OTHER_SERVER_NAME,
- ts_now_ms=clock.time_msec(),
- ts_expires_ms=clock.time_msec() + 10000,
- key_json_bytes=canonicaljson.encode_canonical_json(
- {
- "verify_keys": {
- verify_key_id: {
- "key": signedjson.key.encode_verify_key_base64(
- verify_key
- )
- }
+ ts_added_ms=clock.time_msec(),
+ verify_keys={
+ verify_key_id: FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000
+ ),
+ },
+ response_json={
+ "verify_keys": {
+ verify_key_id: {
+ "key": signedjson.key.encode_verify_key_base64(verify_key)
}
}
- ),
+ },
)
)
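As in the keyring test changes earlier, FederatingHomeserverTestCase now stores the other server's key via store_server_keys_response, passing the parsed keys as FetchKeyResult values alongside the raw response JSON. A minimal sketch of that call shape, mirroring the hunk above (store, clock and the generated key are placeholders standing in for the fixture's own objects):

    import signedjson.key

    from synapse.storage.keys import FetchKeyResult


    async def store_other_server_keys(store, clock) -> None:
        # `store` is the main datastore and `clock` the homeserver clock; the
        # call below mirrors the new store_server_keys_response signature.
        signing_key = signedjson.key.generate_signing_key("1")
        verify_key = signedjson.key.get_verify_key(signing_key)
        verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)

        await store.store_server_keys_response(
            "other.example.com",
            from_server="other.example.com",
            ts_added_ms=clock.time_msec(),
            verify_keys={
                verify_key_id: FetchKeyResult(
                    verify_key=verify_key,
                    valid_until_ts=clock.time_msec() + 10000,
                )
            },
            response_json={
                "verify_keys": {
                    verify_key_id: {
                        "key": signedjson.key.encode_verify_key_base64(verify_key)
                    }
                }
            },
        )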
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 064f4987df..168419f440 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -623,14 +623,14 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
a = A()
- for k in range(0, 12):
+ for k in range(12):
yield a.func(k)
self.assertEqual(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
- for k in range(0, 12):
+ for k in range(12):
yield a.func(k)
self.assertTrue(