diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 765258c47a..69a4e9413b 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -46,15 +46,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
"was: %r" % (config.key.macaroon_secret_key,)
)
- config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ assert config2 is not None
self.assertTrue(
- hasattr(config.key, "macaroon_secret_key"),
+ hasattr(config2.key, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key",
)
- if len(config.key.macaroon_secret_key) < 5:
+ if len(config2.key.macaroon_secret_key) < 5:
self.fail(
"Want macaroon secret key to be string of at least length 5,"
- "was: %r" % (config.key.macaroon_secret_key,)
+ "was: %r" % (config2.key.macaroon_secret_key,)
)
def test_load_succeeds_if_macaroon_secret_key_missing(self):
@@ -62,6 +63,9 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ assert config1 is not None
+ assert config2 is not None
+ assert config3 is not None
self.assertEqual(
config1.key.macaroon_secret_key, config2.key.macaroon_secret_key
)
@@ -78,14 +82,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config = HomeServerConfig.load_config("", ["-c", self.config_file])
self.assertFalse(config.registration.enable_registration)
- config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
- self.assertFalse(config.registration.enable_registration)
+ config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ assert config2 is not None
+ self.assertFalse(config2.registration.enable_registration)
# Check that either config value is clobbered by the command line.
- config = HomeServerConfig.load_or_generate_config(
+ config3 = HomeServerConfig.load_or_generate_config(
"", ["-c", self.config_file, "--enable-registration"]
)
- self.assertTrue(config.registration.enable_registration)
+ assert config3 is not None
+ self.assertTrue(config3.registration.enable_registration)
def test_stats_enabled(self):
self.generate_config_and_remove_lines_containing("enable_metrics")
@@ -94,3 +100,12 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
# The default Metrics Flags are off by default.
config = HomeServerConfig.load_config("", ["-c", self.config_file])
self.assertFalse(config.metrics.metrics_flags.known_servers)
+
+ def test_deprecated_identity_server_flag_throws_error(self):
+ self.generate_config()
+ # Needed to ensure that the key/value pair added below doesn't end up on a line with a comment
+ self.add_lines_to_config([" "])
+ # Check that the presence of "trust_identity_server_for_password_resets" raises a ConfigError
+ self.add_lines_to_config(["trust_identity_server_for_password_resets: true"])
+ with self.assertRaises(ConfigError):
+ HomeServerConfig.load_config("", ["-c", self.config_file])
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index cbecc1c20f..17a9fb63a1 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -1,4 +1,4 @@
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@ import signedjson.sign
from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
+from twisted.internet import defer
from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
@@ -40,7 +41,7 @@ from synapse.storage.keys import FetchKeyResult
from tests import unittest
from tests.test_utils import make_awaitable
-from tests.unittest import logcontext_clean
+from tests.unittest import logcontext_clean, override_config
class MockPerspectiveServer:
@@ -197,7 +198,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# self.assertFalse(d.called)
self.get_success(d)
- def test_verify_for_server_locally(self):
+ def test_verify_for_local_server(self):
"""Ensure that locally signed JSON can be verified without fetching keys
over federation
"""
@@ -209,6 +210,56 @@ class KeyringTestCase(unittest.HomeserverTestCase):
d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
self.get_success(d)
+ OLD_KEY = signedjson.key.generate_signing_key("old")
+
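+ # The key ID in the config below is "<alg>:<version>" ("ed25519:old" here);
+ # "expired_ts" is a millisecond POSIX timestamp marking when the key expired.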
+ @override_config(
+ {
+ "old_signing_keys": {
+ f"{OLD_KEY.alg}:{OLD_KEY.version}": {
+ "key": encode_verify_key_base64(OLD_KEY.verify_key),
+ "expired_ts": 1000,
+ }
+ }
+ }
+ )
+ def test_verify_for_local_server_old_key(self):
+ """Can also use keys in old_signing_keys for verification"""
+ json1 = {}
+ signedjson.sign.sign_json(json1, self.hs.hostname, self.OLD_KEY)
+
+ kr = keyring.Keyring(self.hs)
+ d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
+ self.get_success(d)
+
+ def test_verify_for_local_server_unknown_key(self):
+ """Local keys that we no longer have should be fetched via the fetcher"""
+
+ # the key we'll sign things with (nb, not known to the Keyring)
+ key2 = signedjson.key.generate_signing_key("2")
+
+ # set up a mock fetcher which will return the key
+ async def get_keys(
+ server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+ ) -> Dict[str, FetchKeyResult]:
+ self.assertEqual(server_name, self.hs.hostname)
+ self.assertEqual(key_ids, [get_key_id(key2)])
+
+ return {get_key_id(key2): FetchKeyResult(get_verify_key(key2), 1200)}
+
+ mock_fetcher = Mock()
+ mock_fetcher.get_keys = Mock(side_effect=get_keys)
+ kr = keyring.Keyring(
+ self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
+ )
+
+ # sign the json
+ json1 = {}
+ signedjson.sign.sign_json(json1, self.hs.hostname, key2)
+
+ # ... and check we can verify it.
+ d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
+ self.get_success(d)
+
def test_verify_json_for_server_with_null_valid_until_ms(self):
"""Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms`
@@ -527,6 +578,76 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
+ def test_get_multiple_keys_from_perspectives(self):
+ """Check that we can correctly request multiple keys for the same server"""
+
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
+ SERVER_NAME = "server2"
+
+ testkey1 = signedjson.key.generate_signing_key("ver1")
+ testverifykey1 = signedjson.key.get_verify_key(testkey1)
+ testverifykey1_id = "ed25519:ver1"
+
+ testkey2 = signedjson.key.generate_signing_key("ver2")
+ testverifykey2 = signedjson.key.get_verify_key(testkey2)
+ testverifykey2_id = "ed25519:ver2"
+
+ VALID_UNTIL_TS = 200 * 1000
+
+ response1 = self.build_perspectives_response(
+ SERVER_NAME,
+ testkey1,
+ VALID_UNTIL_TS,
+ )
+ response2 = self.build_perspectives_response(
+ SERVER_NAME,
+ testkey2,
+ VALID_UNTIL_TS,
+ )
+
+ async def post_json(destination, path, data, **kwargs):
+ self.assertEqual(destination, self.mock_perspective_server.server_name)
+ self.assertEqual(path, "/_matrix/key/v2/query")
+
+ # check that the request is for the expected keys
+ q = data["server_keys"]
+
+ self.assertEqual(
+ list(q[SERVER_NAME].keys()), [testverifykey1_id, testverifykey2_id]
+ )
+ return {"server_keys": [response1, response2]}
+
+ self.http_client.post_json.side_effect = post_json
+
+ # fire off two separate requests; they should get merged together into a
+ # single HTTP hit.
+ request1_d = defer.ensureDeferred(
+ fetcher.get_keys(SERVER_NAME, [testverifykey1_id], 0)
+ )
+ request2_d = defer.ensureDeferred(
+ fetcher.get_keys(SERVER_NAME, [testverifykey2_id], 0)
+ )
+
+ keys1 = self.get_success(request1_d)
+ self.assertIn(testverifykey1_id, keys1)
+ k = keys1[testverifykey1_id]
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey1)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
+
+ keys2 = self.get_success(request2_d)
+ self.assertIn(testverifykey2_id, keys2)
+ k = keys2[testverifykey2_id]
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey2)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver2")
+
+ # finally, ensure that only one request was sent
+ self.assertEqual(self.http_client.post_json.call_count, 1)
+
def test_get_perspectives_own_key(self):
"""Check that we can get the perspectives server's own keys
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
new file mode 100644
index 0000000000..0b19159961
--- /dev/null
+++ b/tests/federation/transport/test_client.py
@@ -0,0 +1,50 @@
+import json
+
+from synapse.api.room_versions import RoomVersions
+from synapse.federation.transport.client import SendJoinParser
+
+from tests.unittest import TestCase
+
+
+class SendJoinParserTestCase(TestCase):
+ def test_two_writes(self) -> None:
+ """Test that the parser can sensibly deserialise an input given in two slices."""
+ parser = SendJoinParser(RoomVersions.V1, True)
+ parent_event = {
+ "content": {
+ "see_room_version_spec": "The event format changes depending on the room version."
+ },
+ "event_id": "$authparent",
+ "room_id": "!somewhere:example.org",
+ "type": "m.room.minimal_pdu",
+ }
+ state = {
+ "content": {
+ "see_room_version_spec": "The event format changes depending on the room version."
+ },
+ "event_id": "$DoNotThinkAboutTheEvent",
+ "room_id": "!somewhere:example.org",
+ "type": "m.room.minimal_pdu",
+ }
+ response = [
+ 200,
+ {
+ "auth_chain": [parent_event],
+ "origin": "matrix.org",
+ "state": [state],
+ },
+ ]
+ serialised_response = json.dumps(response).encode()
+
+ # Send data to the parser
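+ # (The split point of 100 bytes is arbitrary; the parser is expected to
+ # buffer partial JSON between writes.)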
+ parser.write(serialised_response[:100])
+ parser.write(serialised_response[100:])
+
+ # Retrieve the parsed SendJoinResponse
+ parsed_response = parser.finish()
+
+ # Sanity check the parsing gave us sensible data.
+ self.assertEqual(len(parsed_response.auth_events), 1, parsed_response)
+ self.assertEqual(len(parsed_response.state), 1, parsed_response)
+ self.assertEqual(parsed_response.event_dict, {}, parsed_response)
+ self.assertIsNone(parsed_response.event, parsed_response)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 12857053e7..72e176da75 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -116,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
)
)
@@ -134,7 +134,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_failure(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
),
ResourceLimitError,
@@ -162,7 +162,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# If not in monthly active cohort
self.get_failure(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
),
ResourceLimitError,
@@ -179,7 +179,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.clock.time_msec())
)
self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
)
)
@@ -197,7 +197,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
# Ensure does not raise exception
self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
)
)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index be008227df..0ea4e753e2 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from unittest.mock import Mock
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
-from synapse.config.room_directory import RoomDirectoryConfig
from synapse.rest.client import directory, login, room
from synapse.types import RoomAlias, create_requester
@@ -394,22 +393,15 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def prepare(self, reactor, clock, hs):
- # We cheekily override the config to add custom alias creation rules
- config = {}
+ def default_config(self):
+ config = super().default_config()
+
+ # Add custom alias creation rules to the config.
config["alias_creation_rules"] = [
{"user_id": "*", "alias": "#unofficial_*", "action": "allow"}
]
- config["room_list_publication_rules"] = []
- rd_config = RoomDirectoryConfig()
- rd_config.read_config(config)
-
- self.hs.config.roomdirectory.is_alias_creation_allowed = (
- rd_config.is_alias_creation_allowed
- )
-
- return hs
+ return config
def test_denied(self):
room_id = self.helper.create_room_as(self.user_id)
@@ -417,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
b"directory/room/%23test%3Atest",
- ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
+ {"room_id": room_id},
)
self.assertEquals(403, channel.code, channel.result)
@@ -427,14 +419,35 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
b"directory/room/%23unofficial_test%3Atest",
- ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
+ {"room_id": room_id},
)
self.assertEquals(200, channel.code, channel.result)
+ def test_denied_during_creation(self):
+ """A room alias that is not allowed should be rejected during creation."""
+ # Invalid room alias.
+ self.helper.create_room_as(
+ self.user_id,
+ expect_code=403,
+ extra_content={"room_alias_name": "foo"},
+ )
-class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
- data = {"room_alias_name": "unofficial_test"}
+ def test_allowed_during_creation(self):
+ """A valid room alias should be allowed during creation."""
+ room_id = self.helper.create_room_as(
+ self.user_id,
+ extra_content={"room_alias_name": "unofficial_test"},
+ )
+ channel = self.make_request(
+ "GET",
+ b"directory/room/%23unofficial_test%3Atest",
+ )
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(channel.json_body["room_id"], room_id)
+
+
+class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -443,27 +456,30 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
]
hijack_auth = False
- def prepare(self, reactor, clock, hs):
- self.allowed_user_id = self.register_user("allowed", "pass")
- self.allowed_access_token = self.login("allowed", "pass")
+ data = {"room_alias_name": "unofficial_test"}
+ allowed_localpart = "allowed"
- self.denied_user_id = self.register_user("denied", "pass")
- self.denied_access_token = self.login("denied", "pass")
+ def default_config(self):
+ config = super().default_config()
- # This time we add custom room list publication rules
- config = {}
- config["alias_creation_rules"] = []
+ # Add custom room list publication rules to the config.
config["room_list_publication_rules"] = [
+ {
+ "user_id": "@" + self.allowed_localpart + "*",
+ "alias": "#unofficial_*",
+ "action": "allow",
+ },
{"user_id": "*", "alias": "*", "action": "deny"},
- {"user_id": self.allowed_user_id, "alias": "*", "action": "allow"},
]
- rd_config = RoomDirectoryConfig()
- rd_config.read_config(config)
+ return config
- self.hs.config.roomdirectory.is_publishing_room_allowed = (
- rd_config.is_publishing_room_allowed
- )
+ def prepare(self, reactor, clock, hs):
+ self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
+ self.allowed_access_token = self.login(self.allowed_localpart, "pass")
+
+ self.denied_user_id = self.register_user("denied", "pass")
+ self.denied_access_token = self.login("denied", "pass")
return hs
@@ -505,10 +521,23 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.allowed_user_id,
tok=self.allowed_access_token,
extra_content=self.data,
- is_public=False,
+ is_public=True,
expect_code=200,
)
+ def test_denied_publication_with_invalid_alias(self):
+ """
+ Try to create a room, register an alias for it, and publish it,
+ as a user WITH permission to publish rooms.
+ """
+ self.helper.create_room_as(
+ self.allowed_user_id,
+ tok=self.allowed_access_token,
+ extra_content={"room_alias_name": "foo"},
+ is_public=True,
+ expect_code=403,
+ )
+
def test_can_create_as_private_room_after_rejection(self):
"""
After failing to publish a room with an alias as a user without publish permission,
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 0c3b86fda9..f0723892e4 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -162,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "key1"}
+ fallback_key2 = {"alg1:k2": "key2"}
otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet
@@ -213,6 +214,35 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
+ # re-uploading the same fallback key should still result in no unused fallback
+ # keys
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"org.matrix.msc2732.fallback_keys": fallback_key},
+ )
+ )
+
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, [])
+
+ # uploading a new fallback key should result in an unused fallback key
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"org.matrix.msc2732.fallback_keys": fallback_key2},
+ )
+ )
+
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, ["alg1"])
+
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
self.get_success(
@@ -238,7 +268,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
res,
- {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)
def test_replace_master_key(self):
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 86beb8ff08..e85d112ecc 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -14,6 +14,8 @@
from typing import Any, Iterable, List, Optional, Tuple
from unittest import mock
+from twisted.internet.defer import ensureDeferred
+
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -30,7 +32,7 @@ from synapse.handlers.room_summary import _RoomEntry, child_events_comparison_ke
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, create_requester
from tests import unittest
@@ -247,7 +249,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_rooms(result, expected)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -261,7 +263,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
expected = [(self.space, [self.room]), (self.room, ())]
self._assert_rooms(result, expected)
- result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ )
self._assert_hierarchy(result, expected)
# If the space is made invite-only, it should no longer be viewable.
@@ -272,7 +276,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
- self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
+ self.get_failure(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space),
+ AuthError,
+ )
# If the space is made world-readable it should return a result.
self.helper.send_state(
@@ -284,7 +291,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_space_summary(user2, self.space))
self._assert_rooms(result, expected)
- result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ )
self._assert_hierarchy(result, expected)
# Make it not world-readable again and confirm it results in an error.
@@ -295,7 +304,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
- self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
+ self.get_failure(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space),
+ AuthError,
+ )
# Join the space and results should be returned.
self.helper.invite(self.space, targ=user2, tok=self.token)
@@ -303,7 +315,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_space_summary(user2, self.space))
self._assert_rooms(result, expected)
- result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ )
self._assert_hierarchy(result, expected)
# Attempting to view an unknown room returns the same error.
@@ -312,10 +326,67 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
AuthError,
)
self.get_failure(
- self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname),
+ self.handler.get_room_hierarchy(
+ create_requester(user2), "#not-a-space:" + self.hs.hostname
+ ),
AuthError,
)
+ def test_room_hierarchy_cache(self) -> None:
+ """In-flight room hierarchy requests are deduplicated."""
+ # Run two `get_room_hierarchy` calls up until they block.
+ deferred1 = ensureDeferred(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ deferred2 = ensureDeferred(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+
+ # Complete the two calls.
+ result1 = self.get_success(deferred1)
+ result2 = self.get_success(deferred2)
+
+ # Both `get_room_hierarchy` calls should return the same result.
+ expected = [(self.space, [self.room]), (self.room, ())]
+ self._assert_hierarchy(result1, expected)
+ self._assert_hierarchy(result2, expected)
+ self.assertIs(result1, result2)
+
+ # A subsequent `get_room_hierarchy` call should not reuse the result.
+ result3 = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self._assert_hierarchy(result3, expected)
+ self.assertIsNot(result1, result3)
+
+ def test_room_hierarchy_cache_sharing(self) -> None:
+ """Room hierarchy responses for different users are not shared."""
+ user2 = self.register_user("user2", "pass")
+
+ # Make the room within the space invite-only.
+ self.helper.send_state(
+ self.room,
+ event_type=EventTypes.JoinRules,
+ body={"join_rule": JoinRules.INVITE},
+ tok=self.token,
+ )
+
+ # Run two `get_room_hierarchy` calls for different users up until they block.
+ deferred1 = ensureDeferred(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ deferred2 = ensureDeferred(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ )
+
+ # Complete the two calls.
+ result1 = self.get_success(deferred1)
+ result2 = self.get_success(deferred2)
+
+ # The `get_room_hierarchy` calls should return different results.
+ self._assert_hierarchy(result1, [(self.space, [self.room]), (self.room, ())])
+ self._assert_hierarchy(result2, [(self.space, [self.room])])
+
def _create_room_with_join_rule(
self, join_rule: str, room_version: Optional[str] = None, **extra_content
) -> str:
@@ -410,7 +481,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
]
self._assert_rooms(result, expected)
- result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ )
self._assert_hierarchy(result, expected)
def test_complex_space(self):
@@ -452,7 +525,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_rooms(result, expected)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -467,7 +540,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
room_ids.append(self.room)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, limit=7
+ )
)
# The result should have the space and all of the links, plus some of the
# rooms and a pagination token.
@@ -479,7 +554,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# Check the next page.
result = self.get_success(
self.handler.get_room_hierarchy(
- self.user, self.space, limit=5, from_token=result["next_batch"]
+ create_requester(self.user),
+ self.space,
+ limit=5,
+ from_token=result["next_batch"],
)
)
# The result should have the space and the room in it, along with a link
@@ -499,20 +577,22 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
room_ids.append(self.room)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, limit=7
+ )
)
self.assertIn("next_batch", result)
# Changing the room ID, suggested-only, or max-depth causes an error.
self.get_failure(
self.handler.get_room_hierarchy(
- self.user, self.room, from_token=result["next_batch"]
+ create_requester(self.user), self.room, from_token=result["next_batch"]
),
SynapseError,
)
self.get_failure(
self.handler.get_room_hierarchy(
- self.user,
+ create_requester(self.user),
self.space,
suggested_only=True,
from_token=result["next_batch"],
@@ -521,14 +601,19 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self.get_failure(
self.handler.get_room_hierarchy(
- self.user, self.space, max_depth=0, from_token=result["next_batch"]
+ create_requester(self.user),
+ self.space,
+ max_depth=0,
+ from_token=result["next_batch"],
),
SynapseError,
)
# An invalid token is ignored.
self.get_failure(
- self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"),
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, from_token="foo"
+ ),
SynapseError,
)
@@ -554,14 +639,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# Test just the space itself.
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, max_depth=0)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, max_depth=0
+ )
)
expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])]
self._assert_hierarchy(result, expected)
# A single additional layer.
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, max_depth=1)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, max_depth=1
+ )
)
expected += [
(rooms[0], ()),
@@ -571,7 +660,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# A few layers.
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, max_depth=3)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, max_depth=3
+ )
)
expected += [
(rooms[1], ()),
@@ -602,7 +693,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_rooms(result, expected)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -684,7 +775,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
new=summarize_remote_room_hierarchy,
):
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -851,7 +942,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
new=summarize_remote_room_hierarchy,
):
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -909,7 +1000,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
new=summarize_remote_room_hierarchy,
):
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index 1f9a2f9b1d..c8cc21cadd 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -36,8 +36,11 @@ class ServerNameTestCase(unittest.TestCase):
"localhost:http", # non-numeric port
"1234]", # smells like ipv6 literal but isn't
"[1234",
+ "[1.2.3.4]",
"underscore_.com",
"percent%65.com",
+ "newline.com\n",
+ ".empty-label.com",
"1234:5678:80", # too many colons
]
for i in test_data:
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 90f800e564..f8cba7b645 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -128,6 +128,7 @@ class EmailPusherTests(HomeserverTestCase):
)
self.auth_handler = hs.get_auth_handler()
+ self.store = hs.get_datastore()
def test_need_validated_email(self):
"""Test that we can only add an email pusher if the user has validated
@@ -408,13 +409,7 @@ class EmailPusherTests(HomeserverTestCase):
self.hs.get_datastore().db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(
- self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.hs.get_datastore().db_pool.updates.do_next_background_update(100),
- by=0.1,
- )
+ self.wait_for_background_updates()
# Check that all pushers with unlinked addresses were deleted
pushers = self.get_success(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 0a6e4795ee..596ba5a0c9 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -17,6 +17,7 @@ from unittest.mock import patch
from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client import login, room, sync
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
@@ -193,7 +194,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
#
# Worker2's event stream position will not advance until we call
# __aexit__ again.
- actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
+ worker_store2 = worker_hs2.get_datastore()
+ assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
+
+ actx = worker_store2._stream_id_gen.get_next()
self.get_success(actx.__aenter__())
response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 78c48db552..62f242baf6 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -11,10 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
+from typing import Collection
+
+from parameterized import parameterized
import synapse.rest.admin
+from synapse.api.errors import Codes
from synapse.rest.client import login
from synapse.server import HomeServer
+from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
@@ -30,6 +36,60 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
+ @parameterized.expand(
+ [
+ ("GET", "/_synapse/admin/v1/background_updates/enabled"),
+ ("POST", "/_synapse/admin/v1/background_updates/enabled"),
+ ("GET", "/_synapse/admin/v1/background_updates/status"),
+ ("POST", "/_synapse/admin/v1/background_updates/start_job"),
+ ]
+ )
+ def test_requester_is_no_admin(self, method: str, url: str):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ self.register_user("user", "pass", admin=False)
+ other_user_tok = self.login("user", "pass")
+
+ channel = self.make_request(
+ method,
+ url,
+ content={},
+ access_token=other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+ url = "/_synapse/admin/v1/background_updates/start_job"
+
+ # empty content
+ channel = self.make_request(
+ "POST",
+ url,
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # job_name invalid
+ channel = self.make_request(
+ "POST",
+ url,
+ content={"job_name": "unknown"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
def _register_bg_update(self):
"Adds a bg update but doesn't start it"
@@ -60,7 +120,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Background updates should be enabled, but none should be running.
self.assertDictEqual(
@@ -75,14 +135,14 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self._register_bg_update()
self.store.db_pool.updates.start_doing_background_updates()
- self.reactor.pump([1.0, 1.0])
+ self.reactor.pump([1.0, 1.0, 1.0])
channel = self.make_request(
"GET",
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Background updates should be enabled, and one should be running.
self.assertDictEqual(
@@ -91,9 +151,11 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 1000.0,
- "total_item_count": 100,
+ "total_item_count": (
+ BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": True,
@@ -114,7 +176,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/enabled",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
# Disable the BG updates
@@ -124,7 +186,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": False})
# Advance a bit and get the current status, note this will finish the in
@@ -137,16 +199,18 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(
channel.json_body,
{
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 1000.0,
- "total_item_count": 100,
+ "total_item_count": (
+ BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": False,
@@ -162,7 +226,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# There should be no change from the previous /status response.
self.assertDictEqual(
@@ -171,9 +235,11 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 1000.0,
- "total_item_count": 100,
+ "total_item_count": (
+ BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": False,
@@ -188,7 +254,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
@@ -199,7 +265,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Background updates should be enabled and making progress.
self.assertDictEqual(
@@ -208,11 +274,92 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 2000.0,
- "total_item_count": 200,
+ "total_item_count": (
+ 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": True,
},
)
+
+ @parameterized.expand(
+ [
+ ("populate_stats_process_rooms", ["populate_stats_process_rooms"]),
+ (
+ "regenerate_directory",
+ [
+ "populate_user_directory_createtables",
+ "populate_user_directory_process_rooms",
+ "populate_user_directory_process_users",
+ "populate_user_directory_cleanup",
+ ],
+ ),
+ ]
+ )
+ def test_start_background_job(self, job_name: str, updates: Collection[str]):
+ """
+ Test that background updates are added to the database and processed.
+
+ Args:
+ job_name: name of the job to start via the API
+ updates: collection of background updates to be started
+ """
+
+ # no background update is waiting
+ self.assertTrue(
+ self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ )
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/background_updates/start_job",
+ content={"job_name": job_name},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ # test that each background update is waiting now
+ for update in updates:
+ self.assertFalse(
+ self.get_success(
+ self.store.db_pool.updates.has_completed_background_update(update)
+ )
+ )
+
+ self.wait_for_background_updates()
+
+ # background updates are done
+ self.assertTrue(
+ self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ )
+ )
+
+ def test_start_background_job_twice(self):
+ """Test that add a background update twice return an error."""
+
+ # add job to database
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="background_updates",
+ values={
+ "update_name": "populate_stats_process_rooms",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/background_updates/start_job",
+ content={"job_name": "populate_stats_process_rooms"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index b48fc12e5f..07077aff78 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -2226,6 +2226,234 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
)
+class BlockRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self._store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = "/_synapse/admin/v1/rooms/%s/block"
+
+ @parameterized.expand([("PUT",), ("GET",)])
+ def test_requester_is_no_admin(self, method: str):
+ """If the user is not a server admin, an error 403 is returned."""
+
+ channel = self.make_request(
+ method,
+ self.url % self.room_id,
+ content={},
+ access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @parameterized.expand([("PUT",), ("GET",)])
+ def test_room_is_not_valid(self, method: str):
+ """Check that invalid room names, return an error 400."""
+
+ channel = self.make_request(
+ method,
+ self.url % "invalidroom",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "invalidroom is not a legal room ID",
+ channel.json_body["error"],
+ )
+
+ def test_block_is_not_valid(self):
+ """If parameter `block` is not valid, return an error."""
+
+ # `block` is not valid
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={"block": "NotBool"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ # `block` is not set
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # no content is sent
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+ def test_block_room(self):
+ """Test that block a room is successful."""
+
+ def _request_and_test_block_room(room_id: str) -> None:
+ self._is_blocked(room_id, expect=False)
+ channel = self.make_request(
+ "PUT",
+ self.url % room_id,
+ content={"block": True},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertTrue(channel.json_body["block"])
+ self._is_blocked(room_id, expect=True)
+
+ # known internal room
+ _request_and_test_block_room(self.room_id)
+
+ # unknown internal room
+ _request_and_test_block_room("!unknown:test")
+
+ # unknown remote room
+ _request_and_test_block_room("!unknown:remote")
+
+ def test_block_room_twice(self):
+ """Test that block a room that is already blocked is successful."""
+
+ self._is_blocked(self.room_id, expect=False)
+ for _ in range(2):
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={"block": True},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertTrue(channel.json_body["block"])
+ self._is_blocked(self.room_id, expect=True)
+
+ def test_unblock_room(self):
+ """Test that unblock a room is successful."""
+
+ def _request_and_test_unblock_room(room_id: str) -> None:
+ self._block_room(room_id)
+
+ channel = self.make_request(
+ "PUT",
+ self.url % room_id,
+ content={"block": False},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body["block"])
+ self._is_blocked(room_id, expect=False)
+
+ # known internal room
+ _request_and_test_unblock_room(self.room_id)
+
+ # unknown internal room
+ _request_and_test_unblock_room("!unknown:test")
+
+ # unknown remote room
+ _request_and_test_unblock_room("!unknown:remote")
+
+ def test_unblock_room_twice(self):
+ """Test that unblock a room that is not blocked is successful."""
+
+ self._block_room(self.room_id)
+ for _ in range(2):
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={"block": False},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body["block"])
+ self._is_blocked(self.room_id, expect=False)
+
+ def test_get_blocked_room(self):
+ """Test get status of a blocked room"""
+
+ def _request_blocked_room(room_id: str) -> None:
+ self._block_room(room_id)
+
+ channel = self.make_request(
+ "GET",
+ self.url % room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertTrue(channel.json_body["block"])
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+
+ # known internal room
+ _request_blocked_room(self.room_id)
+
+ # unknown internal room
+ _request_blocked_room("!unknown:test")
+
+ # unknown remote room
+ _request_blocked_room("!unknown:remote")
+
+ def test_get_unblocked_room(self):
+ """Test get status of a unblocked room"""
+
+ def _request_unblocked_room(room_id: str) -> None:
+ self._is_blocked(room_id, expect=False)
+
+ channel = self.make_request(
+ "GET",
+ self.url % room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body["block"])
+ self.assertNotIn("user_id", channel.json_body)
+
+ # known internal room
+ _request_unblocked_room(self.room_id)
+
+ # unknown internal room
+ _request_unblocked_room("!unknown:test")
+
+ # unknown remote room
+ _request_unblocked_room("!unknown:remote")
+
+ def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+ """Assert that the room is blocked or not"""
+ d = self._store.is_room_blocked(room_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
+
+ def _block_room(self, room_id: str) -> None:
+ """Block a room in database"""
+ self.get_success(self._store.block_room(room_id, self.other_user))
+ self._is_blocked(room_id, expect=True)
+
+
PURGE_TABLES = [
"current_state_events",
"event_backward_extremities",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index c9fe0f06c2..5011e54563 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1169,14 +1169,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# regardless of whether password login or SSO is allowed
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.admin_user, device_id=None, valid_until_ms=None
)
)
self.other_user = self.register_user("user", "pass", displayname="User")
self.other_user_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.other_user, device_id=None, valid_until_ms=None
)
)
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index e2fcbdc63a..d8a94f4c12 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
from typing import Optional, Union
from twisted.internet.defer import succeed
@@ -513,12 +514,26 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
+ def use_refresh_token(self, refresh_token: str) -> FakeChannel:
+ """
+ Helper that makes a request to use a refresh token.
+ """
+ return self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+ {"refresh_token": refresh_token},
+ )
+
def test_login_issue_refresh_token(self):
"""
A login response should include a refresh_token only if asked.
"""
# Test login
- body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ }
login_without_refresh = self.make_request(
"POST", "/_matrix/client/r0/login", body
@@ -528,8 +543,8 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
login_with_refresh = self.make_request(
"POST",
- "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
- body,
+ "/_matrix/client/r0/login",
+ {"org.matrix.msc2918.refresh_token": True, **body},
)
self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
self.assertIn("refresh_token", login_with_refresh.json_body)
@@ -555,11 +570,12 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
register_with_refresh = self.make_request(
"POST",
- "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true",
+ "/_matrix/client/r0/register",
{
"username": "test3",
"password": self.user_pass,
"auth": {"type": LoginType.DUMMY},
+ "org.matrix.msc2918.refresh_token": True,
},
)
self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
@@ -570,10 +586,15 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
"""
A refresh token can be used to issue a new access token.
"""
- body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
login_response = self.make_request(
"POST",
- "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ "/_matrix/client/r0/login",
body,
)
self.assertEqual(login_response.code, 200, login_response.result)
@@ -598,15 +619,20 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response.json_body["refresh_token"],
)
- @override_config({"access_token_lifetime": "1m"})
- def test_refresh_token_expiration(self):
+ @override_config({"refreshable_access_token_lifetime": "1m"})
+ def test_refreshable_access_token_expiration(self):
"""
The access token should have some time as specified in the config.
"""
- body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
login_response = self.make_request(
"POST",
- "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ "/_matrix/client/r0/login",
body,
)
self.assertEqual(login_response.code, 200, login_response.result)
@@ -623,6 +649,128 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.assertApproximates(
refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
)
+ access_token = refresh_response.json_body["access_token"]
+
+ # Advance 59 seconds into the future (just shy of 1 minute, the time of expiry)
+ self.reactor.advance(59.0)
+ # Check that our token is valid
+ self.assertEqual(
+ self.make_request(
+ "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
+ ).code,
+ HTTPStatus.OK,
+ )
+
+ # Advance 2 more seconds (just past the time of expiry)
+ self.reactor.advance(2.0)
+ # Check that our token is invalid
+ self.assertEqual(
+ self.make_request(
+ "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
+ ).code,
+ HTTPStatus.UNAUTHORIZED,
+ )
+
+ @override_config(
+ {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
+ )
+ def test_refresh_token_expiry(self):
+ """
+ The refresh token can be configured to have a limited lifetime.
+ When that lifetime has ended, the refresh token can no longer be used to
+ refresh the session.
+ """
+
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
+ login_response = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
+ refresh_token1 = login_response.json_body["refresh_token"]
+
+ # Advance 119 seconds into the future (just shy of 2 minutes)
+ self.reactor.advance(119.0)
+
+ # Refresh our session. The refresh token should still JUST be valid right now.
+ # By doing so, we get a new access token and a new refresh token.
+ refresh_response = self.use_refresh_token(refresh_token1)
+ self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
+ self.assertIn(
+ "refresh_token",
+ refresh_response.json_body,
+ "No new refresh token returned after refresh.",
+ )
+ refresh_token2 = refresh_response.json_body["refresh_token"]
+
+ # Advance 121 seconds into the future (just a bit more than 2 minutes)
+ self.reactor.advance(121.0)
+
+ # Try to refresh our session, but instead notice that the refresh token is
+ # not valid (it just expired).
+ refresh_response = self.use_refresh_token(refresh_token2)
+ self.assertEqual(
+ refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
+ )
+
+ @override_config(
+ {
+ "refreshable_access_token_lifetime": "2m",
+ "refresh_token_lifetime": "2m",
+ "session_lifetime": "3m",
+ }
+ )
+ def test_ultimate_session_expiry(self):
+ """
+ The session can be configured to have an ultimate, limited lifetime.
+ """
+
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
+ login_response = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(login_response.code, 200, login_response.result)
+ refresh_token = login_response.json_body["refresh_token"]
+
+ # Advance just shy of 2 minutes into the future
+ self.reactor.advance(119.0)
+
+ # Refresh our session. The refresh token should still be valid right now.
+ refresh_response = self.use_refresh_token(refresh_token)
+ self.assertEqual(refresh_response.code, 200, refresh_response.result)
+ self.assertIn(
+ "refresh_token",
+ refresh_response.json_body,
+ "No new refresh token returned after refresh.",
+ )
+ # Notice that our access token lifetime has been diminished to match the
+ # session lifetime.
+ # 3 minutes - 119 seconds = 61 seconds.
+ self.assertEqual(refresh_response.json_body["expires_in_ms"], 61_000)
+ refresh_token = refresh_response.json_body["refresh_token"]
+
+ # Advance 61 seconds into the future. Our session should have expired
+ # now, because we've had our 3 minutes.
+ self.reactor.advance(61.0)
+
+ # Try to issue a new, refreshed, access token.
+ # This should fail because the refresh token's lifetime has also been
+ # diminished as our session expired.
+ refresh_response = self.use_refresh_token(refresh_token)
+ self.assertEqual(refresh_response.code, 403, refresh_response.result)
def test_refresh_token_invalidation(self):
"""Refresh tokens are invalidated after first use of the next token.
@@ -640,10 +788,15 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|-> fourth_refresh (fails)
"""
- body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
login_response = self.make_request(
"POST",
- "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ "/_matrix/client/r0/login",
body,
)
self.assertEqual(login_response.code, 200, login_response.result)
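The expiry tests above call a `use_refresh_token` helper defined earlier in this test file and elided by the diff. A minimal sketch of what it presumably looks like, assuming the MSC2918 unstable refresh endpoint (both the endpoint path and the `FakeChannel` return type are assumptions here, not confirmed by the hunk):

```python
from tests.server import FakeChannel  # assumed Synapse test helper type

def use_refresh_token(self, refresh_token: str) -> FakeChannel:
    # Hypothetical sketch: POST the refresh token to the unstable MSC2918
    # refresh endpoint and hand back the channel for the caller to inspect.
    return self.make_request(
        "POST",
        "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
        {"refresh_token": refresh_token},
    )
```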
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index b9e3602552..249808b031 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"localdb_enabled": False}})
def test_get_change_password_capabilities_localdb_disabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
@@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"enabled": False}})
def test_get_change_password_capabilities_password_disabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
@@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"experimental_features": {"msc3244_enabled": False}})
def test_get_does_not_include_msc3244_fields_when_disabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
@@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
def test_get_does_include_msc3244_fields_when_enabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 0b90e3f803..19f5e46537 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -815,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
jwt_secret = "secret"
jwt_algorithm = "HS256"
+ base_config = {
+ "enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ }
- def make_homeserver(self, reactor, clock):
- self.hs = self.setup_test_homeserver()
- self.hs.config.jwt.jwt_enabled = True
- self.hs.config.jwt.jwt_secret = self.jwt_secret
- self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
- return self.hs
+ def default_config(self):
+ config = super().default_config()
+
+ # If jwt_config has been defined (eg via @override_config), don't replace it.
+ if config.get("jwt_config") is None:
+ config["jwt_config"] = self.base_config
+
+ return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
@@ -879,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
- @override_config(
- {
- "jwt_config": {
- "jwt_enabled": True,
- "secret": jwt_secret,
- "algorithm": jwt_algorithm,
- "issuer": "test-issuer",
- }
- }
- )
+ @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
def test_login_iss(self):
"""Test validating the issuer claim."""
# A valid issuer.
@@ -919,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
- @override_config(
- {
- "jwt_config": {
- "jwt_enabled": True,
- "secret": jwt_secret,
- "algorithm": jwt_algorithm,
- "audiences": ["test-audience"],
- }
- }
- )
+ @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
def test_login_aud(self):
"""Test validating the audience claim."""
# A valid audience.
@@ -962,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
+ def test_login_default_sub(self):
+ """Test reading user ID from the default subject claim."""
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
+ def test_login_custom_sub(self):
+ """Test reading user ID from a custom subject claim."""
+ channel = self.jwt_login({"username": "frog"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
def test_login_no_token(self):
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -1024,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
]
)
- def make_homeserver(self, reactor, clock):
- self.hs = self.setup_test_homeserver()
- self.hs.config.jwt.jwt_enabled = True
- self.hs.config.jwt.jwt_secret = self.jwt_pubkey
- self.hs.config.jwt.jwt_algorithm = "RS256"
- return self.hs
+ def default_config(self):
+ config = super().default_config()
+ config["jwt_config"] = {
+ "enabled": True,
+ "secret": self.jwt_pubkey,
+ "algorithm": "RS256",
+ }
+ return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
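Both JWT test classes end their hunks at a `jwt_encode` helper whose body the diff elides. Going by the comment shown on the line above, a sketch of how such a helper typically normalises PyJWT's 2.0 return-type change (a sketch against the class attributes shown, not the file's confirmed body):

```python
from typing import Any, Dict, Union

import jwt  # PyJWT

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
    # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str;
    # normalise to str so the tests pass against both major versions.
    result: Union[bytes, str] = jwt.encode(payload, secret, self.jwt_algorithm)
    if isinstance(result, bytes):
        return result.decode("ascii")
    return result
```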
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 78c2fb86b9..eb10d43217 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1,4 +1,5 @@
# Copyright 2019 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return config
def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
self.user_id, self.user_token = self._create_user("alice")
self.user2_id, self.user2_token = self._create_user("bob")
@@ -91,6 +94,49 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)
+ def test_deny_invalid_event(self):
+ """Test that we deny relations on non-existant events"""
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ EventTypes.Message,
+ parent_id="foo",
+ content={"body": "foo", "msgtype": "m.text"},
+ )
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ # Unless that event is referenced from another event!
+ self.get_success(
+ self.hs.get_datastore().db_pool.simple_insert(
+ table="event_relations",
+ values={
+ "event_id": "bar",
+ "relates_to_id": "foo",
+ "relation_type": RelationTypes.THREAD,
+ },
+ desc="test_deny_invalid_event",
+ )
+ )
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ parent_id="foo",
+ content={"body": "foo", "msgtype": "m.text"},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ def test_deny_invalid_room(self):
+ """Test that we deny relations on non-existant events"""
+ # Create another room and send a message in it.
+ room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+ res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+ parent_id = res["event_id"]
+
+ # Attempt to send an annotation to that event.
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+ )
+ self.assertEquals(400, channel.code, channel.json_body)
+
def test_deny_double_react(self):
"""Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
@@ -99,6 +145,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(400, channel.code, channel.json_body)
+ def test_deny_forked_thread(self):
+ """It is invalid to start a thread off a thread."""
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo"},
+ parent_id=self.parent_id,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ parent_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo"},
+ parent_id=parent_id,
+ )
+ self.assertEquals(400, channel.code, channel.json_body)
+
def test_basic_paginate_relations(self):
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -703,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIn("chunk", channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
+ def test_unknown_relations(self):
+ """Unknown relations should be accepted."""
+ channel = self._send_relation("m.relation.test", "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the full
+ # relation event we sent above.
+ self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
+ channel.json_body["chunk"][0],
+ )
+
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEquals(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
+ # When bundling aggregations, the unknown relation is not included.
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
+ # But unknown relations can be directly queried.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEquals(channel.json_body["chunk"], [])
+
def _send_relation(
self,
relation_type: str,
@@ -749,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token = self.login(localpart, "abc123")
return user_id, access_token
+
+ def test_background_update(self):
+ """Test the event_arbitrary_relations background update."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_event_id_good = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_event_id_bad = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_event_id = channel.json_body["event_id"]
+
+ # Clean-up the table as if the inserts did not happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="event_relations",
+ column="event_id",
+ iterable=(annotation_event_id_bad, thread_event_id),
+ keyvalues={},
+ desc="RelationsTestCase.test_background_update",
+ )
+ )
+
+ # Only the "good" annotation should be found.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEquals(
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
+ [annotation_event_id_good],
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "event_arbitrary_relations", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # The "good" annotation and the thread should be found, but not the "bad"
+ # annotation.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertCountEqual(
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
+ [annotation_event_id_good, thread_event_id],
+ )
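The `_send_relation` helper driving all of these tests is truncated by the diff; only the first lines of its signature appear above. A hedged reconstruction of its likely shape (the endpoint path and the query-string handling are assumptions):

```python
import json
import urllib.parse
from typing import Optional

def _send_relation(
    self,
    relation_type: str,
    event_type: str,
    key: Optional[str] = None,
    content: Optional[dict] = None,
    access_token: Optional[str] = None,
    parent_id: Optional[str] = None,
):
    # Hypothetical sketch: send an event relating to `parent_id` (or the
    # test's default parent) over the unstable send_relation endpoint.
    if access_token is None:
        access_token = self.user_token
    query = "?key=" + urllib.parse.quote(key) if key else ""
    original_id = parent_id if parent_id else self.parent_id
    return self.make_request(
        "POST",
        f"/_matrix/client/unstable/rooms/{self.room}/send_relation"
        f"/{original_id}/{relation_type}/{event_type}{query}",
        json.dumps(content or {}).encode("utf-8"),
        access_token=access_token,
    )
```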
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
index 09504a485f..8fe94f7d85 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/rest/media/v1/test_filepath.py
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
+from typing import Iterable
+
from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest
@@ -236,3 +239,250 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge",
],
)
+
+ def test_server_name_validation(self):
+ """Test validation of server names"""
+ self._test_path_validation(
+ [
+ "remote_media_filepath_rel",
+ "remote_media_filepath",
+ "remote_media_thumbnail_rel",
+ "remote_media_thumbnail",
+ "remote_media_thumbnail_rel_legacy",
+ "remote_media_thumbnail_dir",
+ ],
+ parameter="server_name",
+ valid_values=[
+ "matrix.org",
+ "matrix.org:8448",
+ "matrix-federation.matrix.org",
+ "matrix-federation.matrix.org:8448",
+ "10.1.12.123",
+ "10.1.12.123:8448",
+ "[fd00:abcd::ffff]",
+ "[fd00:abcd::ffff]:8448",
+ ],
+ invalid_values=[
+ "/matrix.org",
+ "matrix.org/..",
+ "matrix.org\x00",
+ "",
+ ".",
+ "..",
+ "/",
+ ],
+ )
+
+ def test_file_id_validation(self):
+ """Test validation of local, remote and legacy URL cache file / media IDs"""
+ # File / media IDs get split into three parts to form paths, consisting of the
+ # first two characters, next two characters and rest of the ID.
+ valid_file_ids = [
+ "GerZNDnDZVjsOtardLuwfIBg",
+ # Unexpected, but produces an acceptable path:
+ "GerZN", # "N" becomes the last directory
+ ]
+ invalid_file_ids = [
+ "/erZNDnDZVjsOtardLuwfIBg",
+ "Ge/ZNDnDZVjsOtardLuwfIBg",
+ "GerZ/DnDZVjsOtardLuwfIBg",
+ "GerZ/..",
+ "G\x00rZNDnDZVjsOtardLuwfIBg",
+ "Ger\x00NDnDZVjsOtardLuwfIBg",
+ "GerZNDnDZVjsOtardLuwfIBg\x00",
+ "",
+ "Ge",
+ "GerZ",
+ "GerZ.",
+ "..rZNDnDZVjsOtardLuwfIBg",
+ "Ge..NDnDZVjsOtardLuwfIBg",
+ "GerZ..",
+ "GerZ/",
+ ]
+
+ self._test_path_validation(
+ [
+ "local_media_filepath_rel",
+ "local_media_filepath",
+ "local_media_thumbnail_rel",
+ "local_media_thumbnail",
+ "local_media_thumbnail_dir",
+ # Legacy URL cache media IDs
+ "url_cache_filepath_rel",
+ "url_cache_filepath",
+ # `url_cache_filepath_dirs_to_delete` is tested below.
+ "url_cache_thumbnail_rel",
+ "url_cache_thumbnail",
+ "url_cache_thumbnail_directory_rel",
+ "url_cache_thumbnail_directory",
+ "url_cache_thumbnail_dirs_to_delete",
+ ],
+ parameter="media_id",
+ valid_values=valid_file_ids,
+ invalid_values=invalid_file_ids,
+ )
+
+ # `url_cache_filepath_dirs_to_delete` ignores what would be the last path
+ # component, so only the first 4 characters matter.
+ self._test_path_validation(
+ [
+ "url_cache_filepath_dirs_to_delete",
+ ],
+ parameter="media_id",
+ valid_values=valid_file_ids,
+ invalid_values=[
+ "/erZNDnDZVjsOtardLuwfIBg",
+ "Ge/ZNDnDZVjsOtardLuwfIBg",
+ "G\x00rZNDnDZVjsOtardLuwfIBg",
+ "Ger\x00NDnDZVjsOtardLuwfIBg",
+ "",
+ "Ge",
+ "..rZNDnDZVjsOtardLuwfIBg",
+ "Ge..NDnDZVjsOtardLuwfIBg",
+ ],
+ )
+
+ self._test_path_validation(
+ [
+ "remote_media_filepath_rel",
+ "remote_media_filepath",
+ "remote_media_thumbnail_rel",
+ "remote_media_thumbnail",
+ "remote_media_thumbnail_rel_legacy",
+ "remote_media_thumbnail_dir",
+ ],
+ parameter="file_id",
+ valid_values=valid_file_ids,
+ invalid_values=invalid_file_ids,
+ )
+
+ def test_url_cache_media_id_validation(self):
+ """Test validation of URL cache media IDs"""
+ self._test_path_validation(
+ [
+ "url_cache_filepath_rel",
+ "url_cache_filepath",
+ # `url_cache_filepath_dirs_to_delete` only cares about the date prefix
+ "url_cache_thumbnail_rel",
+ "url_cache_thumbnail",
+ "url_cache_thumbnail_directory_rel",
+ "url_cache_thumbnail_directory",
+ "url_cache_thumbnail_dirs_to_delete",
+ ],
+ parameter="media_id",
+ valid_values=[
+ "2020-01-02_GerZNDnDZVjsOtar",
+ "2020-01-02_G", # Unexpected, but produces an acceptable path
+ ],
+ invalid_values=[
+ "2020-01-02",
+ "2020-01-02-",
+ "2020-01-02-.",
+ "2020-01-02-..",
+ "2020-01-02-/",
+ "2020-01-02-/GerZNDnDZVjsOtar",
+ "2020-01-02-GerZNDnDZVjsOtar/..",
+ "2020-01-02-GerZNDnDZVjsOtar\x00",
+ ],
+ )
+
+ def test_content_type_validation(self):
+ """Test validation of thumbnail content types"""
+ self._test_path_validation(
+ [
+ "local_media_thumbnail_rel",
+ "local_media_thumbnail",
+ "remote_media_thumbnail_rel",
+ "remote_media_thumbnail",
+ "remote_media_thumbnail_rel_legacy",
+ "url_cache_thumbnail_rel",
+ "url_cache_thumbnail",
+ ],
+ parameter="content_type",
+ valid_values=[
+ "image/jpeg",
+ ],
+ invalid_values=[
+ "", # ValueError: not enough values to unpack
+ "image/jpeg/abc", # ValueError: too many values to unpack
+ "image/jpeg\x00",
+ ],
+ )
+
+ def test_thumbnail_method_validation(self):
+ """Test validation of thumbnail methods"""
+ self._test_path_validation(
+ [
+ "local_media_thumbnail_rel",
+ "local_media_thumbnail",
+ "remote_media_thumbnail_rel",
+ "remote_media_thumbnail",
+ "url_cache_thumbnail_rel",
+ "url_cache_thumbnail",
+ ],
+ parameter="method",
+ valid_values=[
+ "crop",
+ "scale",
+ ],
+ invalid_values=[
+ "/scale",
+ "scale/..",
+ "scale\x00",
+ "/",
+ ],
+ )
+
+ def _test_path_validation(
+ self,
+ methods: Iterable[str],
+ parameter: str,
+ valid_values: Iterable[str],
+ invalid_values: Iterable[str],
+ ):
+ """Test that the specified methods validate the named parameter as expected
+
+ Args:
+ methods: The names of `MediaFilePaths` methods to test
+ parameter: The name of the parameter to test
+ valid_values: A list of parameter values that are expected to be accepted
+ invalid_values: A list of parameter values that are expected to be rejected
+
+ Raises:
+ AssertionError: If a value was accepted when it should have failed
+ validation.
+ ValueError: If a value failed validation when it should have been accepted.
+ """
+ for method in methods:
+ get_path = getattr(self.filepaths, method)
+
+ parameters = inspect.signature(get_path).parameters
+ kwargs = {
+ "server_name": "matrix.org",
+ "media_id": "GerZNDnDZVjsOtardLuwfIBg",
+ "file_id": "GerZNDnDZVjsOtardLuwfIBg",
+ "width": 800,
+ "height": 600,
+ "content_type": "image/jpeg",
+ "method": "scale",
+ }
+
+ if get_path.__name__.startswith("url_"):
+ kwargs["media_id"] = "2020-01-02_GerZNDnDZVjsOtar"
+
+ kwargs = {k: v for k, v in kwargs.items() if k in parameters}
+ kwargs.pop(parameter)
+
+ for value in valid_values:
+ kwargs[parameter] = value
+ get_path(**kwargs)
+ # No exception should be raised
+
+ for value in invalid_values:
+ with self.assertRaises(ValueError):
+ kwargs[parameter] = value
+ path_or_list = get_path(**kwargs)
+ self.fail(
+ f"{value!r} unexpectedly passed validation: "
+ f"{method} returned {path_or_list!r}"
+ )
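To make the behaviour `_test_path_validation` asserts concrete, a small usage sketch, assuming a `MediaFilePaths` rooted at `/media_store` as in this test case and the validation introduced by this change:

```python
from synapse.rest.media.v1.filepath import MediaFilePaths

filepaths = MediaFilePaths("/media_store")

# A well-formed media ID is split 2/2/rest into nested directories.
print(filepaths.local_media_filepath("GerZNDnDZVjsOtardLuwfIBg"))
# -> /media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg

# Traversal attempts are rejected rather than turned into paths.
try:
    filepaths.local_media_filepath("../../etc/passwd")
except ValueError as e:
    print("rejected:", e)
```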
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index 4b67bd15b7..36c933b9e9 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -66,7 +66,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "remove_deleted_devices_from_device_inbox",
+ "update_name": "remove_dead_devices_from_device_inbox",
"progress_json": "{}",
},
)
@@ -140,7 +140,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "remove_hidden_devices_from_device_inbox",
+ "update_name": "remove_dead_devices_from_device_inbox",
"progress_json": "{}",
},
)
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index a649e8c618..5ae491ff5a 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -12,11 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from contextlib import contextmanager
+from typing import Generator
+from twisted.enterprise.adbapi import ConnectionPool
+from twisted.internet.defer import ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import EventFormatVersions, RoomVersions
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import (
+ EVENT_QUEUE_THREADS,
+ EventsWorkerStore,
+)
+from synapse.storage.types import Connection
+from synapse.util import Clock
from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest
@@ -144,3 +157,127 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+
+class DatabaseOutageTestCase(unittest.HomeserverTestCase):
+ """Test event fetching during a database outage."""
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.store: EventsWorkerStore = hs.get_datastore()
+
+ self.room_id = f"!room:{hs.hostname}"
+ self.event_ids = [f"event{i}" for i in range(20)]
+
+ self._populate_events()
+
+ def _populate_events(self) -> None:
+ """Ensure that there are test events in the database.
+
+ When testing with the in-memory SQLite database, all the events are lost during
+ the simulated outage.
+
+ To ensure consistency between `room_id`s and `event_id`s before and after the
+ outage, rows are built and inserted manually.
+
+ Upserts are used to handle the non-SQLite case where events are not lost.
+ """
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ "rooms",
+ {"room_id": self.room_id},
+ {"room_version": RoomVersions.V4.identifier},
+ )
+ )
+
+ self.event_ids = [f"event{i}" for i in range(20)]
+ for idx, event_id in enumerate(self.event_ids):
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ "events",
+ {"event_id": event_id},
+ {
+ "event_id": event_id,
+ "room_id": self.room_id,
+ "topological_ordering": idx,
+ "stream_ordering": idx,
+ "type": "test",
+ "processed": True,
+ "outlier": False,
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ "event_json",
+ {"event_id": event_id},
+ {
+ "room_id": self.room_id,
+ "json": json.dumps({"type": "test", "room_id": self.room_id}),
+ "internal_metadata": "{}",
+ "format_version": EventFormatVersions.V3,
+ },
+ )
+ )
+
+ @contextmanager
+ def _outage(self) -> Generator[None, None, None]:
+ """Simulate a database outage.
+
+ Returns:
+ A context manager. While the context is active, any attempts to connect to
+ the database will fail.
+ """
+ connection_pool = self.store.db_pool._db_pool
+
+ # Close all connections and shut down the database `ThreadPool`.
+ connection_pool.close()
+
+ # Restart the database `ThreadPool`.
+ connection_pool.start()
+
+ original_connection_factory = connection_pool.connectionFactory
+
+ def connection_factory(_pool: ConnectionPool) -> Connection:
+ raise Exception("Could not connect to the database.")
+
+ connection_pool.connectionFactory = connection_factory # type: ignore[assignment]
+ try:
+ yield
+ finally:
+ connection_pool.connectionFactory = original_connection_factory
+
+ # If the in-memory SQLite database is being used, all the events are gone.
+ # Restore the test data.
+ self._populate_events()
+
+ def test_failure(self) -> None:
+ """Test that event fetches do not get stuck during a database outage."""
+ with self._outage():
+ failure = self.get_failure(
+ self.store.get_event(self.event_ids[0]), Exception
+ )
+ self.assertEqual(str(failure.value), "Could not connect to the database.")
+
+ def test_recovery(self) -> None:
+ """Test that event fetchers recover after a database outage."""
+ with self._outage():
+ # Kick off a bunch of event fetches but do not pump the reactor
+ event_deferreds = []
+ for event_id in self.event_ids:
+ event_deferreds.append(ensureDeferred(self.store.get_event(event_id)))
+
+ # We should have maxed out on event fetcher threads
+ self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS)
+
+ # All the event fetchers will fail
+ self.pump()
+ self.assertEqual(self.store._event_fetch_ongoing, 0)
+
+ for event_deferred in event_deferreds:
+ failure = self.get_failure(event_deferred, Exception)
+ self.assertEqual(
+ str(failure.value), "Could not connect to the database."
+ )
+
+ # This next event fetch should succeed
+ self.get_success(self.store.get_event(self.event_ids[0]))
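The `_outage` trick is plain Twisted `adbapi`: `ConnectionPool` routes every new connection through its `connectionFactory` attribute, so swapping that attribute makes connection attempts fail on demand. A standalone sketch of the same idea (illustrative only; executing the query requires a running reactor):

```python
from twisted.enterprise.adbapi import ConnectionPool

# An in-memory SQLite pool, as a stand-in for Synapse's database pool.
pool = ConnectionPool("sqlite3", ":memory:", check_same_thread=False)

def failing_factory(pool: ConnectionPool):
    # As in `_outage` above: every new connection attempt now fails.
    raise Exception("Could not connect to the database.")

original_factory = pool.connectionFactory
pool.connectionFactory = failing_factory  # outage begins
try:
    d = pool.runQuery("SELECT 1")  # this Deferred will errback
    d.addErrback(lambda failure: print("failed:", failure.value))
finally:
    pool.connectionFactory = original_factory  # outage over
```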
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 0da42b5ac5..216d816d56 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,8 +1,11 @@
from unittest.mock import Mock
+
+from twisted.internet.defer import Deferred, ensureDeferred
from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
+from tests.test_utils import make_awaitable
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
@@ -19,11 +22,11 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
def test_do_background_update(self):
- # the time we claim each update takes
- duration_ms = 42
+ # the time we claim it takes to update one item when running the update
+ duration_ms = 10
# the target runtime for each bg update
- target_background_update_duration_ms = 50000
+ target_background_update_duration_ms = 100
store = self.hs.get_datastore()
self.get_success(
@@ -48,16 +51,14 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update
self.update_handler.reset_mock()
res = self.get_success(
- self.updates.do_next_background_update(
- target_background_update_duration_ms
- ),
- by=0.1,
+ self.updates.do_next_background_update(False),
+ by=0.01,
)
self.assertFalse(res)
# on the first call, we should get run with the minimum background update size
self.update_handler.assert_called_once_with(
- {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE
+ {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE
)
# second step: complete the update
@@ -74,16 +75,93 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update
self.update_handler.reset_mock()
- result = self.get_success(
- self.updates.do_next_background_update(target_background_update_duration_ms)
- )
+ result = self.get_success(self.updates.do_next_background_update(False))
self.assertFalse(result)
self.update_handler.assert_called_once()
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = self.get_success(
- self.updates.do_next_background_update(target_background_update_duration_ms)
- )
+ result = self.get_success(self.updates.do_next_background_update(False))
self.assertTrue(result)
self.assertFalse(self.update_handler.called)
+
+
+class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
+ # the base test class should have run the real bg updates for us
+ self.assertTrue(
+ self.get_success(self.updates.has_completed_background_updates())
+ )
+
+ self.update_deferred = Deferred()
+ self.update_handler = Mock(return_value=self.update_deferred)
+ self.updates.register_background_update_handler(
+ "test_update", self.update_handler
+ )
+
+ # Mock out the AsyncContextManager
+ self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
+ self._update_ctx_manager.__aenter__ = Mock(
+ return_value=make_awaitable(None),
+ )
+ self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
+
+ # Mock out the `update_handler` callback
+ self._on_update = Mock(return_value=self._update_ctx_manager)
+
+ # Define a default batch size value that's not the same as the internal default
+ # value (100).
+ self._default_batch_size = 500
+
+ # Register the callbacks with more mocks
+ self.hs.get_module_api().register_background_update_controller_callbacks(
+ on_update=self._on_update,
+ min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
+ default_batch_size=Mock(
+ return_value=make_awaitable(self._default_batch_size),
+ ),
+ )
+
+ def test_controller(self):
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": "{}"},
+ )
+ )
+
+ # Set the return value for the context manager.
+ enter_defer = Deferred()
+ self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
+
+ # Start the background update.
+ do_update_d = ensureDeferred(self.updates.do_next_background_update(True))
+
+ self.pump()
+
+ # `run_update` should have been called, but the update handler won't be
+ # called until the `enter_defer` (returned by `__aenter__`) is resolved.
+ self._on_update.assert_called_once_with(
+ "test_update",
+ "master",
+ False,
+ )
+ self.assertFalse(do_update_d.called)
+ self.assertFalse(self.update_deferred.called)
+
+ # Resolving the `enter_defer` should call the update handler, which then
+ # blocks.
+ enter_defer.callback(100)
+ self.pump()
+ self.update_handler.assert_called_once_with({}, self._default_batch_size)
+ self.assertFalse(self.update_deferred.called)
+ self._update_ctx_manager.__aexit__.assert_not_called()
+
+ # Resolving the update handler deferred should cause the
+ # `do_next_background_update` to finish and return
+ self.update_deferred.callback(100)
+ self.pump()
+ self._update_ctx_manager.__aexit__.assert_called()
+ self.get_success(do_update_d)
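The `Mock(spec=["__aenter__", "__aexit__"])` setup works because `mock` attaches magic methods to a per-instance subclass, so `async with` (which looks dunders up on the type) still finds them. A standalone sketch of the pattern, with a stand-in for Synapse's `make_awaitable` helper (an assumption about that helper's behaviour):

```python
import asyncio
from unittest.mock import Mock

async def main() -> None:
    def make_awaitable(result):
        # Assumed to mirror tests.test_utils.make_awaitable: a pre-resolved
        # Future, awaitable like the result of a finished async call.
        future = asyncio.get_running_loop().create_future()
        future.set_result(result)
        return future

    ctx = Mock(spec=["__aenter__", "__aexit__"])
    ctx.__aenter__ = Mock(return_value=make_awaitable(None))
    ctx.__aexit__ = Mock(return_value=make_awaitable(False))

    async with ctx:  # both dunders resolve immediately here
        pass

    ctx.__aenter__.assert_called_once()
    ctx.__aexit__.assert_called_once()

asyncio.run(main())
```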
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index b31c5eb5ec..7b7f6c349e 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
):
iterations += 1
self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(False), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
@@ -723,7 +723,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
):
iterations += 1
self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(False), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index a6be9a1bb1..cfc8098af6 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List
+from unittest import mock
+
from synapse.app.generic_worker import GenericWorkerServer
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
@@ -19,6 +22,22 @@ from synapse.storage.schema import SCHEMA_VERSION
from tests.unittest import HomeserverTestCase
+def fake_listdir(filepath: str) -> List[str]:
+ """
+ A fake implementation of os.listdir which we can use to mock out the filesystem.
+
+ Args:
+ filepath: The directory to list files for.
+
+ Returns:
+ A list of files and folders in the directory.
+ """
+ if filepath.endswith("full_schemas"):
+ return [str(SCHEMA_VERSION)]
+
+ return ["99_add_unicorn_to_database.sql"]
+
+
class WorkerSchemaTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
@@ -51,7 +70,7 @@ class WorkerSchemaTests(HomeserverTestCase):
prepare_database(db_conn, db_pool.engine, self.hs.config)
- def test_not_upgraded(self):
+ def test_not_upgraded_old_schema_version(self):
"""Test that workers don't start if the DB has an older schema version"""
db_pool = self.hs.get_datastore().db_pool
db_conn = LoggingDatabaseConnection(
@@ -67,3 +86,34 @@ class WorkerSchemaTests(HomeserverTestCase):
with self.assertRaises(PrepareDatabaseException):
prepare_database(db_conn, db_pool.engine, self.hs.config)
+
+ def test_not_upgraded_current_schema_version_with_outstanding_deltas(self):
+ """
+ Test that workers don't start if the DB is on the current schema version,
+ but there are still outstanding delta migrations to run.
+ """
+ db_pool = self.hs.get_datastore().db_pool
+ db_conn = LoggingDatabaseConnection(
+ db_pool._db_pool.connect(),
+ db_pool.engine,
+ "tests",
+ )
+
+ # Set the schema version of the database to the current version
+ cur = db_conn.cursor()
+ cur.execute("UPDATE schema_version SET version = ?", (SCHEMA_VERSION,))
+
+ db_conn.commit()
+
+ # Patch `os.listdir` here to make Synapse think that there is a migration
+ # file ready to be run.
+ # Note that we can't patch this function for the whole method, else Synapse
+ # will try to find the file when building the database initially.
+ with mock.patch("os.listdir", mock.Mock(side_effect=fake_listdir)):
+ with self.assertRaises(PrepareDatabaseException):
+ # Synapse should think that there is an outstanding migration file due to
+ # the patched `os.listdir` above.
+ #
+ # We expect Synapse to raise an exception to indicate the master process
+ # needs to apply this migration file.
+ prepare_database(db_conn, db_pool.engine, self.hs.config)
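Outside the test, the `fake_listdir` trick is just `mock.patch` with a `side_effect`, which routes every call through the fake while the `with` block is active:

```python
import os
from unittest import mock

def fake_listdir(filepath: str):
    # As above: pretend one stray delta file exists in every directory.
    return ["99_add_unicorn_to_database.sql"]

with mock.patch("os.listdir", mock.Mock(side_effect=fake_listdir)):
    print(os.listdir("/anywhere"))  # -> ['99_add_unicorn_to_database.sql']

# The real os.listdir is restored once the block exits.
```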
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 37cf7bb232..7f5b28aed8 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -23,6 +23,7 @@ from synapse.rest import admin
from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.background_updates import _BackgroundUpdateHandler
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock
@@ -391,7 +392,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
with mock.patch.dict(
self.store.db_pool.updates._background_update_handlers,
- populate_user_directory_process_users=mocked_process_users,
+ populate_user_directory_process_users=_BackgroundUpdateHandler(
+ mocked_process_users,
+ ),
):
self._purge_and_rebuild_user_dir()
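`mock.patch.dict` above swaps a single entry in the updater's handler registry and restores it on exit; the only new requirement from this change is that the replacement be wrapped in `_BackgroundUpdateHandler`. The general pattern, in miniature:

```python
from unittest import mock

registry = {"populate_user_directory_process_users": "real_handler"}

with mock.patch.dict(registry, populate_user_directory_process_users="mocked"):
    assert registry["populate_user_directory_process_users"] == "mocked"

# The original entry is restored once the context exits.
assert registry["populate_user_directory_process_users"] == "real_handler"
```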
diff --git a/tests/unittest.py b/tests/unittest.py
index c9a08a3420..eea0903f05 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,12 +331,16 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01)
def wait_for_background_updates(self) -> None:
- """Block until all background database updates have completed."""
+ """Block until all background database updates have completed.
+
+ Note that callers must ensure that the test case has a `store` attribute
+ pointing at the datastore.
+ """
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(False), by=0.1
)
def make_homeserver(self, reactor, clock):
@@ -495,8 +499,7 @@ class HomeserverTestCase(TestCase):
async def run_bg_updates():
with LoggingContext("run_bg_updates"):
- while not await stor.db_pool.updates.has_completed_background_updates():
- await stor.db_pool.updates.do_next_background_update(1)
+ self.get_success(stor.db_pool.updates.run_background_updates(False))
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
diff --git a/tests/utils.py b/tests/utils.py
index cf8ba5c5db..983859120f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -119,7 +119,6 @@ def default_config(name, parse=False):
"enable_registration": True,
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
- "trusted_third_party_id_servers": [],
"password_providers": [],
"worker_replication_url": "",
"worker_app": None,
|