diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py
index f57c813a58..5c91031746 100644
--- a/tests/config/test_oauth_delegation.py
+++ b/tests/config/test_oauth_delegation.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from unittest.mock import Mock
from synapse.config import ConfigError
@@ -167,6 +168,21 @@ class MSC3861OAuthDelegation(TestCase):
with self.assertRaises(ConfigError):
self.parse_config()
+ def test_user_consent_cannot_be_enabled(self) -> None:
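+ # Create a real template directory so the consent config parses far
+ # enough to hit the MSC3861 incompatibility check.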
+ tmpdir = self.mktemp()
+ os.mkdir(tmpdir)
+ self.config_dict["user_consent"] = {
+ "require_at_registration": True,
+ "version": "1",
+ "template_dir": tmpdir,
+ "server_notice_content": {
+ "msgtype": "m.text",
+ "body": "foo",
+ },
+ }
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
def test_password_config_cannot_be_enabled(self) -> None:
self.config_dict["password_config"] = {"enabled": True}
with self.assertRaises(ConfigError):
@@ -255,3 +271,8 @@ class MSC3861OAuthDelegation(TestCase):
self.config_dict["session_lifetime"] = "24h"
with self.assertRaises(ConfigError):
self.parse_config()
+
+ def test_enable_3pid_changes_cannot_be_enabled(self) -> None:
+ self.config_dict["enable_3pid_changes"] = True
+ with self.assertRaises(ConfigError):
+ self.parse_config()
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 82c26e303f..b891e84690 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -14,7 +14,7 @@
from http import HTTPStatus
from typing import Any, Dict, Union
-from unittest.mock import ANY, Mock
+from unittest.mock import ANY, AsyncMock, Mock
from urllib.parse import parse_qs
from signedjson.key import (
@@ -340,6 +340,41 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
get_awaitable_result(self.auth.is_server_admin(requester)), False
)
+ def test_active_user_admin_impersonation(self) -> None:
+ """The handler should return a requester with normal user rights
+ and a user ID matching the one specified in the query param `user_id`."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ impersonated_user_id = f"@{USERNAME}:{SERVER_NAME}"
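+ # Ask to impersonate another user via the admin-only query parameter.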
+ request.args[b"_oidc_admin_impersonate_user_id"] = [
+ impersonated_user_id.encode("ascii")
+ ]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), impersonated_user_id)
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(requester.device_id, None)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), False
+ )
+
def test_active_user_with_device(self) -> None:
"""The handler should return a requester with normal user rights and a device ID."""
@@ -553,6 +588,38 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
)
self.assertEqual(self.http_client.request.call_count, 2)
+ def test_revocation_endpoint(self) -> None:
+ # Mock the introspection of the token being revoked, then of the admin's
+ # own access token used to authorize the revocation request
+ self.http_client.request = AsyncMock(
+ side_effect=[
+ FakeResponse.json(
+ code=200, payload={"active": True, "jti": "open_sesame"}
+ ),
+ FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
+ "username": USERNAME,
+ },
+ ),
+ ]
+ )
+
+ # cache a token to delete
+ introspection_token = self.get_success(
+ self.auth._introspect_token("open_sesame") # type: ignore[attr-defined]
+ )
+ self.assertEqual(self.auth._token_cache.get("open_sesame"), introspection_token) # type: ignore[attr-defined]
+
+ # delete the revoked token
+ introspection_token_id = "open_sesame"
+ url = f"/_synapse/admin/v1/OIDC_token_revocation/{introspection_token_id}"
+ channel = self.make_request("DELETE", url, access_token="mockAccessToken")
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(self.auth._token_cache.get("open_sesame"), None) # type: ignore[attr-defined]
+
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index fd66d573d2..1f483eb75a 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -514,6 +514,9 @@ class PresenceTimeoutTestCase(unittest.TestCase):
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
+ user_id = "@test:server"
+ user_id_obj = UserID.from_string(user_id)
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
@@ -523,12 +526,11 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
we time out their syncing users presence.
"""
process_id = "1"
- user_id = "@test:server"
# Notify handler that a user is now syncing.
self.get_success(
self.presence_handler.update_external_syncs_row(
- process_id, user_id, True, self.clock.time_msec()
+ process_id, self.user_id, True, self.clock.time_msec()
)
)
@@ -536,48 +538,37 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# stopped syncing that their presence state doesn't get timed out.
self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2)
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
self.assertEqual(state.state, PresenceState.ONLINE)
# Check that if the external process timeout fires, then the syncing
# user gets timed out
self.reactor.advance(EXTERNAL_PROCESS_EXPIRY)
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
self.assertEqual(state.state, PresenceState.OFFLINE)
def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
"""Test that if a user doesn't update the records for a while
users presence goes `OFFLINE` because of timeout and `status_msg` remains.
"""
- user_id = "@test:server"
status_msg = "I'm here!"
# Mark user as online
- self._set_presencestate_with_status_msg(
- user_id, PresenceState.ONLINE, status_msg
- )
+ self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
# Check that if we wait a while without telling the handler the user has
# stopped syncing that their presence state doesn't get timed out.
self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2)
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
self.assertEqual(state.state, PresenceState.ONLINE)
self.assertEqual(state.status_msg, status_msg)
# Check that if the timeout fires, then the syncing user gets timed out
self.reactor.advance(SYNC_ONLINE_TIMEOUT)
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
# status_msg should remain even after going offline
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, status_msg)
@@ -586,24 +577,19 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
"""Test that if a user change presence manually to `OFFLINE`
and no status is set, that `status_msg` is `None`.
"""
- user_id = "@test:server"
status_msg = "I'm here!"
# Mark user as online
- self._set_presencestate_with_status_msg(
- user_id, PresenceState.ONLINE, status_msg
- )
+ self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
# Mark user as offline
self.get_success(
self.presence_handler.set_state(
- UserID.from_string(user_id), {"presence": PresenceState.OFFLINE}
+ self.user_id_obj, {"presence": PresenceState.OFFLINE}
)
)
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, None)
@@ -611,41 +597,31 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
"""Test that if a user change presence manually to `OFFLINE`
and a status is set, that `status_msg` appears.
"""
- user_id = "@test:server"
status_msg = "I'm here!"
# Mark user as online
- self._set_presencestate_with_status_msg(
- user_id, PresenceState.ONLINE, status_msg
- )
+ self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
# Mark user as offline
- self._set_presencestate_with_status_msg(
- user_id, PresenceState.OFFLINE, "And now here."
- )
+ self._set_presencestate_with_status_msg(PresenceState.OFFLINE, "And now here.")
def test_user_reset_online_with_no_status(self) -> None:
"""Test that if a user set again the presence manually
and no status is set, that `status_msg` is `None`.
"""
- user_id = "@test:server"
status_msg = "I'm here!"
# Mark user as online
- self._set_presencestate_with_status_msg(
- user_id, PresenceState.ONLINE, status_msg
- )
+ self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
# Mark user as online again
self.get_success(
self.presence_handler.set_state(
- UserID.from_string(user_id), {"presence": PresenceState.ONLINE}
+ self.user_id_obj, {"presence": PresenceState.ONLINE}
)
)
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
# status_msg should remain even after going offline
self.assertEqual(state.state, PresenceState.ONLINE)
self.assertEqual(state.status_msg, None)
@@ -654,33 +630,27 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
"""Test that if a user set again the presence manually
and status is `None`, that `status_msg` is `None`.
"""
- user_id = "@test:server"
status_msg = "I'm here!"
# Mark user as online
- self._set_presencestate_with_status_msg(
- user_id, PresenceState.ONLINE, status_msg
- )
+ self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
# Mark user as online and `status_msg = None`
- self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
+ self._set_presencestate_with_status_msg(PresenceState.ONLINE, None)
def test_set_presence_from_syncing_not_set(self) -> None:
"""Test that presence is not set by syncing if affect_presence is false"""
- user_id = "@test:server"
status_msg = "I'm here!"
- self._set_presencestate_with_status_msg(
- user_id, PresenceState.UNAVAILABLE, status_msg
- )
+ self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
self.get_success(
- self.presence_handler.user_syncing(user_id, False, PresenceState.ONLINE)
+ self.presence_handler.user_syncing(
+ self.user_id, False, PresenceState.ONLINE
+ )
)
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
# we should still be unavailable
self.assertEqual(state.state, PresenceState.UNAVAILABLE)
# and status message should still be the same
@@ -688,50 +658,34 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def test_set_presence_from_syncing_is_set(self) -> None:
"""Test that presence is set by syncing if affect_presence is true"""
- user_id = "@test:server"
status_msg = "I'm here!"
- self._set_presencestate_with_status_msg(
- user_id, PresenceState.UNAVAILABLE, status_msg
- )
+ self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
self.get_success(
- self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+ self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
)
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
# we should now be online
self.assertEqual(state.state, PresenceState.ONLINE)
def test_set_presence_from_syncing_keeps_status(self) -> None:
"""Test that presence set by syncing retains status message"""
- user_id = "@test:server"
status_msg = "I'm here!"
- self._set_presencestate_with_status_msg(
- user_id, PresenceState.UNAVAILABLE, status_msg
- )
+ self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
self.get_success(
- self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+ self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
)
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
# our status message should be the same as it was before
self.assertEqual(state.status_msg, status_msg)
@parameterized.expand([(False,), (True,)])
- @unittest.override_config(
- {
- "experimental_features": {
- "msc3026_enabled": True,
- },
- }
- )
+ @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
def test_set_presence_from_syncing_keeps_busy(
self, test_with_workers: bool
) -> None:
@@ -741,7 +695,6 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
test_with_workers: If True, check the presence state of the user by calling
/sync against a worker, rather than the main process.
"""
- user_id = "@test:server"
status_msg = "I'm busy!"
# By default, we call /sync against the main process.
@@ -755,44 +708,39 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
)
# Set presence to BUSY
- self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg)
+ self._set_presencestate_with_status_msg(PresenceState.BUSY, status_msg)
# Perform a sync with a presence state other than busy. This should NOT change
# our presence status; we only change from busy if we explicitly set it via
# /presence/*.
self.get_success(
worker_to_sync_against.get_presence_handler().user_syncing(
- user_id, True, PresenceState.ONLINE
+ self.user_id, True, PresenceState.ONLINE
)
)
# Check against the main process that the user's presence did not change.
- state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
# we should still be busy
self.assertEqual(state.state, PresenceState.BUSY)
def _set_presencestate_with_status_msg(
- self, user_id: str, state: str, status_msg: Optional[str]
+ self, state: str, status_msg: Optional[str]
) -> None:
"""Set a PresenceState and status_msg and check the result.
Args:
- user_id: User for that the status is to be set.
state: The new PresenceState.
status_msg: Status message that is to be set.
"""
self.get_success(
self.presence_handler.set_state(
- UserID.from_string(user_id),
+ self.user_id_obj,
{"presence": state, "status_msg": status_msg},
)
)
- new_state = self.get_success(
- self.presence_handler.get_state(UserID.from_string(user_id))
- )
+ new_state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
self.assertEqual(new_state.state, state)
self.assertEqual(new_state.status_msg, status_msg)
@@ -952,9 +900,6 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertEqual(upto_token, now_token)
self.assertFalse(limited)
- expected_rows = [
- (2, ("dest3", "@user3:test")),
- ]
self.assertCountEqual(rows, [])
prev_token = self.queue.get_current_token(self.instance_name)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9785dd698b..430209705e 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -446,6 +446,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertIsNone(profile)
def test_handle_user_deactivated_support_user(self) -> None:
+ """Ensure a support user doesn't get added to the user directory after deactivation."""
s_user_id = "@support:test"
self.get_success(
self.store.register_user(
@@ -453,14 +454,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
- mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
- with patch.object(
- self.store, "remove_from_user_dir", mock_remove_from_user_dir
- ):
- self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
- # BUG: the correct spelling is assert_not_called, but that makes the test fail
- # and it's not clear that this is actually the behaviour we want.
- mock_remove_from_user_dir.not_called()
+ # The profile should not be in the directory.
+ profile = self.get_success(self.store._get_user_in_directory(s_user_id))
+ self.assertIsNone(profile)
+
+ # Remove the user from the directory.
+ self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
+
+ # The profile should still not be in the user directory.
+ profile = self.get_success(self.store._get_user_in_directory(s_user_id))
+ self.assertIsNone(profile)
def test_handle_user_deactivated_regular_user(self) -> None:
r_user_id = "@regular:test"
diff --git a/tests/replication/test_intro_token_invalidation.py b/tests/replication/test_intro_token_invalidation.py
new file mode 100644
index 0000000000..f90678b6b1
--- /dev/null
+++ b/tests/replication/test_intro_token_invalidation.py
@@ -0,0 +1,62 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict
+
+import synapse.rest.admin._base
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+
+class IntrospectionTokenCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [synapse.rest.admin.register_servlets]
+
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+ config["disable_registration"] = True
+ config["experimental_features"] = {
+ "msc3861": {
+ "enabled": True,
+ "issuer": "some_dude",
+ "client_id": "ID",
+ "client_auth_method": "client_secret_post",
+ "client_secret": "secret",
+ }
+ }
+ return config
+
+ def test_stream_introspection_token_invalidation(self) -> None:
+ worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+ auth = worker_hs.get_auth()
+ store = self.hs.get_datastores().main
+
+ # add a token to the cache on the worker
+ auth._token_cache["open_sesame"] = "intro_token" # type: ignore[attr-defined]
+
+ # stream the invalidation from the master
+ self.get_success(
+ store.stream_introspection_token_invalidation(("open_sesame",))
+ )
+
+ # check that the cache on the worker was invalidated
+ self.assertEqual(auth._token_cache.get("open_sesame"), None) # type: ignore[attr-defined]
+
+ # test invalidating whole cache
+ for i in range(5):
+ auth._token_cache[f"open_sesame_{i}"] = f"intro_token_{i}" # type: ignore[attr-defined]
+ self.assertEqual(len(auth._token_cache), 5) # type: ignore[attr-defined]
+
+ self.get_success(store.stream_introspection_token_invalidation((None,)))
+
+ self.assertEqual(len(auth._token_cache), 0) # type: ignore[attr-defined]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 41a959b4d6..feb81844ae 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -879,6 +879,44 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self._order_test([self.admin_user, user1, user2], "creation_ts", "f")
self._order_test([user2, user1, self.admin_user], "creation_ts", "b")
+ def test_filter_admins(self) -> None:
+ """
+ Tests whether the various values of the query parameter `admins` lead to the
+ expected result set.
+ """
+
+ # Register an additional non admin user
+ self.register_user("user", "pass", admin=False)
+
+ # Query all users
+ channel = self.make_request(
+ "GET",
+ f"{self.url}",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(2, channel.json_body["total"])
+
+ # Query only admin users
+ channel = self.make_request(
+ "GET",
+ f"{self.url}?admins=true",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(1, channel.json_body["users"][0]["admin"])
+
+ # Query only non admin users
+ channel = self.make_request(
+ "GET",
+ f"{self.url}?admins=false",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertFalse(channel.json_body["users"][0]["admin"])
+
@override_config(
{
"experimental_features": {
diff --git a/tests/server.py b/tests/server.py
index 481fe34c5c..ff03d28864 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1000,8 +1000,6 @@ def setup_test_homeserver(
hs.tls_server_context_factory = Mock()
hs.setup()
- if homeserver_to_use == TestHomeServer:
- hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
database_pool = hs.get_datastores().databases[0]
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 9c151a5e62..7a4ecab2d5 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,19 @@
# limitations under the License.
import datetime
-from typing import Dict, List, Tuple, Union, cast
+from typing import (
+ Collection,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Mapping,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
from parameterized import parameterized
@@ -38,6 +50,138 @@ from synapse.util import Clock, json_encoder
import tests.unittest
import tests.utils
+# The silly auth graph we use to test the auth difference algorithm,
+# where the top are the most recent events.
+#
+#     A   B
+#      \ /
+#   D   E
+#    \  |
+#     ` F   C
+#       |  /|
+#       G ´ |
+#       | \ |
+#       H   I
+#       |   |
+#       K   J
+
+AUTH_GRAPH: Dict[str, List[str]] = {
+ "a": ["e"],
+ "b": ["e"],
+ "c": ["g", "i"],
+ "d": ["f"],
+ "e": ["f"],
+ "f": ["g"],
+ "g": ["h", "i"],
+ "h": ["k"],
+ "i": ["j"],
+ "k": [],
+ "j": [],
+}
+
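+# The depth of each event in the auth graph above (larger means more recent).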
+DEPTH_GRAPH = {
+ "a": 7,
+ "b": 7,
+ "c": 4,
+ "d": 6,
+ "e": 6,
+ "f": 5,
+ "g": 3,
+ "h": 2,
+ "i": 2,
+ "k": 1,
+ "j": 1,
+}
+
+T = TypeVar("T")
+
+
+def get_all_topologically_sorted_orders(
+ nodes: Iterable[T],
+ graph: Mapping[T, Collection[T]],
+) -> List[List[T]]:
+ """Given a set of nodes and a graph, return all possible topological
+ orderings.
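+
+ For example (illustration only): with nodes {"a", "b", "c"} and graph
+ {"c": ["a", "b"]}, i.e. c depends on both a and b, the possible orders
+ are ["a", "b", "c"] and ["b", "a", "c"].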
+ """
+
+ # This is implemented by Kahn's algorithm, and forking execution each time
+ # we have a choice over which node to consider next.
+
+ degree_map = {node: 0 for node in nodes}
+ reverse_graph: Dict[T, Set[T]] = {}
+
+ for node, edges in graph.items():
+ if node not in degree_map:
+ continue
+
+ for edge in set(edges):
+ if edge in degree_map:
+ degree_map[node] += 1
+
+ reverse_graph.setdefault(edge, set()).add(node)
+ reverse_graph.setdefault(node, set())
+
+ zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+
+ return _get_all_topologically_sorted_orders_inner(
+ reverse_graph, zero_degree, degree_map
+ )
+
+
+def _get_all_topologically_sorted_orders_inner(
+ reverse_graph: Dict[T, Set[T]],
+ zero_degree: List[T],
+ degree_map: Dict[T, int],
+) -> List[List[T]]:
+ new_paths = []
+
+ # Rather than only choosing *one* item from the list of nodes with zero
+ # degree, we "fork" execution and run the algorithm for each node in the
+ # zero degree.
+ for node in zero_degree:
+ new_degree_map = degree_map.copy()
+ new_zero_degree = zero_degree.copy()
+ new_zero_degree.remove(node)
+
+ for edge in reverse_graph.get(node, []):
+ if edge in new_degree_map:
+ new_degree_map[edge] -= 1
+ if new_degree_map[edge] == 0:
+ new_zero_degree.append(edge)
+
+ paths = _get_all_topologically_sorted_orders_inner(
+ reverse_graph, new_zero_degree, new_degree_map
+ )
+ for path in paths:
+ path.insert(0, node)
+
+ new_paths.extend(paths)
+
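+ # Base case: `zero_degree` was empty, so every node has already been
+ # placed; return a single empty ordering for callers to prepend onto.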
+ if not new_paths:
+ return [[]]
+
+ return new_paths
+
+
+def get_all_topologically_consistent_subsets(
+ nodes: Iterable[T],
+ graph: Mapping[T, Collection[T]],
+) -> Set[FrozenSet[T]]:
+ """Get all subsets of the graph where if node N is in the subgraph, then all
+ nodes that can reach that node (i.e. for all X there exists a path X -> N)
+ are in the subgraph.
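+
+ For example (illustration only): for the chain graph {"b": ["a"], "c": ["b"]}
+ the reversed topological order is ["c", "b", "a"], and this returns its
+ proper prefixes: frozenset(), {"c"} and {"c", "b"}.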
+ """
+ all_topological_orderings = get_all_topologically_sorted_orders(nodes, graph)
+
+ graph_subsets = set()
+ for ordering in all_topological_orderings:
+ ordering.reverse()
+
+ for idx in range(len(ordering)):
+ graph_subsets.add(frozenset(ordering[:idx]))
+
+ return graph_subsets
+
@attr.s(auto_attribs=True, frozen=True, slots=True)
class _BackfillSetupInfo:
@@ -172,49 +316,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
- # The silly auth graph we use to test the auth difference algorithm,
- # where the top are the most recent events.
- #
- #     A   B
- #      \ /
- #   D   E
- #    \  |
- #     ` F   C
- #       |  /|
- #       G ´ |
- #       | \ |
- #       H   I
- #       |   |
- #       K   J
-
- auth_graph: Dict[str, List[str]] = {
- "a": ["e"],
- "b": ["e"],
- "c": ["g", "i"],
- "d": ["f"],
- "e": ["f"],
- "f": ["g"],
- "g": ["h", "i"],
- "h": ["k"],
- "i": ["j"],
- "k": [],
- "j": [],
- }
-
- depth_map = {
- "a": 7,
- "b": 7,
- "c": 4,
- "d": 6,
- "e": 6,
- "f": 5,
- "g": 3,
- "h": 2,
- "i": 2,
- "k": 1,
- "j": 1,
- }
-
# Mark the room as maybe having a cover index.
def store_room(txn: LoggingTransaction) -> None:
@@ -238,9 +339,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def insert_event(txn: LoggingTransaction) -> None:
stream_ordering = 0
- for event_id in auth_graph:
+ for event_id in AUTH_GRAPH:
stream_ordering += 1
- depth = depth_map[event_id]
+ depth = DEPTH_GRAPH[event_id]
self.store.db_pool.simple_insert_txn(
txn,
@@ -260,8 +361,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.persist_events._persist_event_auth_chain_txn(
txn,
[
- cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
- for event_id in auth_graph
+ cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+ for event_id in AUTH_GRAPH
],
)
@@ -344,7 +445,51 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
room_id = self._setup_auth_chain(use_chain_cover_index)
# Now actually test that various combinations give the right result:
+ self.assert_auth_diff_is_expected(room_id)
+
+ @parameterized.expand(
+ [
+ [graph_subset]
+ for graph_subset in get_all_topologically_consistent_subsets(
+ AUTH_GRAPH, AUTH_GRAPH
+ )
+ ]
+ )
+ def test_auth_difference_partial(self, graph_subset: Collection[str]) -> None:
+ """Test that if we only have a chain cover index on a partial subset of
+ the room we still get the correct auth chain difference.
+
+ We do this by removing the chain cover index for every valid subset of the
+ graph.
+ """
+ room_id = self._setup_auth_chain(True)
+
+ for event_id in graph_subset:
+ # Remove chain cover from that event.
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="event_auth_chains",
+ keyvalues={"event_id": event_id},
+ desc="test_auth_difference_partial_remove",
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="event_auth_chain_to_calculate",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "type": "",
+ "state_key": "",
+ },
+ desc="test_auth_difference_partial_remove",
+ )
+ )
+
+ self.assert_auth_diff_is_expected(room_id)
+ def assert_auth_diff_is_expected(self, room_id: str) -> None:
+ """Assert the auth chain difference returns the correct answers."""
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
)
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 1277e1a865..4bcd17a6fc 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -108,3 +108,54 @@ class RetryLimiterTestCase(HomeserverTestCase):
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
+
+ def test_max_retry_interval(self) -> None:
+ """Test that `destination_max_retry_interval` setting works as expected"""
+ store = self.hs.get_datastores().main
+
+ destination_max_retry_interval_ms = (
+ self.hs.config.federation.destination_max_retry_interval_ms
+ )
+
+ self.get_success(get_retry_limiter("test_dest", self.clock, store))
+ self.pump(1)
+
+ failure_ts = self.clock.time_msec()
+
+ # Simulate reaching destination_max_retry_interval
+ self.get_success(
+ store.set_destination_retry_timings(
+ "test_dest",
+ failure_ts=failure_ts,
+ retry_last_ts=failure_ts,
+ retry_interval=destination_max_retry_interval_ms,
+ )
+ )
+
+ # Check it fails
+ self.get_failure(
+ get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+ )
+
+ # Advance past the retry interval so we can try again, then fail once more so the backoff continues
+ self.reactor.advance(destination_max_retry_interval_ms / 1000 + 1)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
+ self.pump(1)
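+ # Simulate a request failing inside the limiter so another failure is
+ # recorded; the AssertionError just stands in for a transport error.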
+ try:
+ with limiter:
+ self.pump(1)
+ raise AssertionError("argh")
+ except AssertionError:
+ pass
+
+ self.pump()
+
+ # retry_interval does not increase and stays at destination_max_retry_interval_ms
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ assert new_timings is not None
+ self.assertEqual(new_timings.retry_interval, destination_max_retry_interval_ms)
+
+ # Check it fails
+ self.get_failure(
+ get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+ )
diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
new file mode 100644
index 0000000000..3a97559bf0
--- /dev/null
+++ b/tests/util/test_task_scheduler.py
@@ -0,0 +1,186 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from twisted.internet.task import deferLater
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.types import JsonMapping, ScheduledTask, TaskStatus
+from synapse.util import Clock
+from synapse.util.task_scheduler import TaskScheduler
+
+from tests import unittest
+
+
+class TestTaskScheduler(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.task_scheduler = hs.get_task_scheduler()
+ self.task_scheduler.register_action(self._test_task, "_test_task")
+ self.task_scheduler.register_action(self._sleeping_task, "_sleeping_task")
+ self.task_scheduler.register_action(self._raising_task, "_raising_task")
+ self.task_scheduler.register_action(self._resumable_task, "_resumable_task")
+
+ async def _test_task(
+ self, task: ScheduledTask, first_launch: bool
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ # This test task will copy the parameters to the result
+ result = None
+ if task.params:
+ result = task.params
+ return (TaskStatus.COMPLETE, result, None)
+
+ def test_schedule_task(self) -> None:
+ """Schedule a task in the future with some parameters to be copied as a result and check it executed correctly.
+ Also check that it get removed after `KEEP_TASKS_FOR_MS`."""
+ timestamp = self.clock.time_msec() + 30 * 1000
+ task_id = self.get_success(
+ self.task_scheduler.schedule_task(
+ "_test_task",
+ timestamp=timestamp,
+ params={"val": 1},
+ )
+ )
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.SCHEDULED)
+ self.assertIsNone(task.result)
+
+ # The timestamp is 30s from now, so the task should have been executed
+ # after the first scheduling loop has run
+ self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.COMPLETE)
+ assert task.result is not None
+ # The passed parameter should have been copied to the result
+ self.assertEqual(task.result.get("val"), 1)
+
+ # Let's wait for the complete task to be deleted and hence unavailable
+ self.reactor.advance((TaskScheduler.KEEP_TASKS_FOR_MS / 1000) + 1)
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ self.assertIsNone(task)
+
+ async def _sleeping_task(
+ self, task: ScheduledTask, first_launch: bool
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ # Sleep for a second
+ await deferLater(self.reactor, 1, lambda: None)
+ return TaskStatus.COMPLETE, None, None
+
+ def test_schedule_lot_of_tasks(self) -> None:
+ """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior."""
+ timestamp = self.clock.time_msec() + 30 * 1000
+ task_ids = []
+ for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1):
+ task_ids.append(
+ self.get_success(
+ self.task_scheduler.schedule_task(
+ "_sleeping_task",
+ timestamp=timestamp,
+ params={"val": i},
+ )
+ )
+ )
+
+ # The timestamp is 30s from now, so the tasks should have been executed
+ # after the first scheduling loop has run
+ self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+
+ # Give the sleeping tasks time to finish
+ self.reactor.advance(1)
+
+ # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one
+ # is still scheduled.
+ tasks = [
+ self.get_success(self.task_scheduler.get_task(task_id))
+ for task_id in task_ids
+ ]
+
+ self.assertEqual(
+ len(
+ [t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE]
+ ),
+ TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
+ )
+
+ scheduled_tasks = [
+ t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED
+ ]
+ self.assertEqual(len(scheduled_tasks), 1)
+
+ self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+ self.reactor.advance(1)
+
+ # Check that the last task has been properly executed after the next scheduler loop run
+ prev_scheduled_task = self.get_success(
+ self.task_scheduler.get_task(scheduled_tasks[0].id)
+ )
+ assert prev_scheduled_task is not None
+ self.assertEqual(
+ prev_scheduled_task.status,
+ TaskStatus.COMPLETE,
+ )
+
+ async def _raising_task(
+ self, task: ScheduledTask, first_launch: bool
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ raise Exception("raising")
+
+ def test_schedule_raising_task(self) -> None:
+ """Schedule a task raising an exception and check it runs to failure and report exception content."""
+ task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task"))
+
+ self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.FAILED)
+ self.assertEqual(task.error, "raising")
+
+ async def _resumable_task(
+ self, task: ScheduledTask, first_launch: bool
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ if task.result and "in_progress" in task.result:
+ return TaskStatus.COMPLETE, {"success": True}, None
+ else:
+ await self.task_scheduler.update_task(task.id, result={"in_progress": True})
+ # Await forever to simulate an aborted task because of a restart
+ await deferLater(self.reactor, 2**16, lambda: None)
+ # This should never be reached
+ return TaskStatus.ACTIVE, None, None
+
+ def test_schedule_resumable_task(self) -> None:
+ """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart."""
+ task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task"))
+
+ self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.ACTIVE)
+
+ # Simulate a synapse restart by emptying the list of running tasks
+ self.task_scheduler._running_tasks = set()
+ self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.COMPLETE)
+ assert task.result is not None
+ self.assertTrue(task.result.get("success"))
|