diff --git a/tests/unittest.py b/tests/unittest.py
index 8a16fd3665..c73195b32b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
import gc
import hashlib
import hmac
+import json
import logging
import secrets
import time
@@ -53,6 +55,7 @@ from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.config._base import Config, RootConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.crypto.event_signing import add_hashes_and_signatures
@@ -67,7 +70,6 @@ from synapse.logging.context import (
)
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
-from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
@@ -124,6 +126,53 @@ def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
return _around
+_TConfig = TypeVar("_TConfig", Config, RootConfig)
+
+
+def deepcopy_config(config: _TConfig) -> _TConfig:
+ new_config: _TConfig
+
+ if isinstance(config, RootConfig):
+ new_config = config.__class__(config.config_files) # type: ignore[arg-type]
+ else:
+ new_config = config.__class__(config.root)
+
+ for attr_name in config.__dict__:
+ if attr_name.startswith("__") or attr_name == "root":
+ continue
+ attr = getattr(config, attr_name)
+ if isinstance(attr, Config):
+ new_attr = deepcopy_config(attr)
+ else:
+ new_attr = attr
+
+ setattr(new_config, attr_name, new_attr)
+
+ return new_config
+
+
+@functools.lru_cache(maxsize=8)
+def _parse_config_dict(config: str) -> RootConfig:
+ config_obj = HomeServerConfig()
+ config_obj.parse_config_dict(json.loads(config), "", "")
+ return config_obj
+
+
+def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig:
+ """Creates a :class:`HomeServerConfig` instance with the given configuration dict.
+
+ This is equivalent to::
+
+ config_obj = HomeServerConfig()
+ config_obj.parse_config_dict(config, "", "")
+
+ but it keeps a cache of `HomeServerConfig` instances and deepcopies them as needed,
+ to avoid validating the whole configuration every time.
+ """
+ config_obj = _parse_config_dict(json.dumps(config, sort_keys=True))
+ return deepcopy_config(config_obj)
+
+
class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
@@ -171,13 +220,20 @@ class TestCase(unittest.TestCase):
#
# The easiest way to do this would be to do a full GC after each test
# run, but that is very expensive. Instead, we disable GC (above) for
- # the duration of the test so that we only need to run a gen-0 GC, which
- # is a lot quicker.
+ # the duration of the test and only run a gen-0 GC, which is a lot
+ # quicker. This doesn't clean up everything, since the TestCase
+ # instance still holds references to objects created during the test,
+ # such as HomeServers, so we do a full GC every so often.
@around(self)
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
gc.collect(0)
+ # Run a full GC every 50 gen-0 GCs.
+ gen0_stats = gc.get_stats()[0]
+ gen0_collections = gen0_stats["collections"]
+ if gen0_collections % 50 == 0:
+ gc.collect()
gc.enable()
set_current_context(SENTINEL_CONTEXT)
@@ -508,7 +564,9 @@ class HomeserverTestCase(TestCase):
client_ip,
)
- def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
+ def setup_test_homeserver(
+ self, name: Optional[str] = None, **kwargs: Any
+ ) -> HomeServer:
"""
Set up the test homeserver, meant to be called by the overridable
make_homeserver. It automatically passes through the test class's
@@ -527,16 +585,25 @@ class HomeserverTestCase(TestCase):
else:
config = kwargs["config"]
+ # The server name can be specified using either the `name` argument or a config
+ # override. The `name` argument takes precedence over any config overrides.
+ if name is not None:
+ config["server_name"] = name
+
# Parse the config from a config dict into a HomeServerConfig
- config_obj = HomeServerConfig()
- config_obj.parse_config_dict(config, "", "")
+ config_obj = make_homeserver_config_obj(config)
kwargs["config"] = config_obj
+ # The server name in the config is now `name`, if provided, or the `server_name`
+ # from a config override, or the default of "test". Whichever it is, we
+ # construct a homeserver with a matching name.
+ kwargs["name"] = config_obj.server.server_name
+
async def run_bg_updates() -> None:
with LoggingContext("run_bg_updates"):
self.get_success(stor.db_pool.updates.run_background_updates(False))
- hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ hs = setup_test_homeserver(self.addCleanup, **kwargs)
stor = hs.get_datastores().main
# Run the database background updates, when running against "master".
@@ -790,19 +857,23 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
self.get_success(
- hs.get_datastores().main.store_server_verify_keys(
+ hs.get_datastores().main.store_server_keys_json(
+ self.OTHER_SERVER_NAME,
+ verify_key_id,
from_server=self.OTHER_SERVER_NAME,
- ts_added_ms=clock.time_msec(),
- verify_keys=[
- (
- self.OTHER_SERVER_NAME,
- verify_key_id,
- FetchKeyResult(
- verify_key=verify_key,
- valid_until_ts=clock.time_msec() + 10000,
- ),
- )
- ],
+ ts_now_ms=clock.time_msec(),
+ ts_expires_ms=clock.time_msec() + 10000,
+ key_json_bytes=canonicaljson.encode_canonical_json(
+ {
+ "verify_keys": {
+ verify_key_id: {
+ "key": signedjson.key.encode_verify_key_base64(
+ verify_key
+ )
+ }
+ }
+ }
+ ),
)
)
|