diff options
author | Olivier Wilkinson (reivilibre) <oliverw@matrix.org> | 2024-01-17 15:02:07 +0000 |
---|---|---|
committer | Olivier Wilkinson (reivilibre) <oliverw@matrix.org> | 2024-01-17 17:18:23 +0000 |
commit | c91ab4bc55f2c31b7bfcb5e76aa0993d3a2cb040 (patch) | |
tree | 25d8bf6e79fc1603e72d4465d077a4869d7752a5 | |
parent | Move `stream_writers` to their own field in the WorkerTemplate (diff) | |
download | synapse-rei/cwas_extension.tar.xz |
Remove `merge_into` and just have `merged` which copies inputs to avoid footguns github/rei/cwas_extension rei/cwas_extension
-rwxr-xr-x | docker/configure_workers_and_start.py | 60 |
1 files changed, 33 insertions, 27 deletions
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index b2a03f075a..80f0a2e542 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -55,6 +55,7 @@ import subprocess import sys from argparse import ArgumentParser from collections import defaultdict +from copy import deepcopy from dataclasses import dataclass, field from itertools import chain from pathlib import Path @@ -321,37 +322,42 @@ def flush_buffers() -> None: sys.stderr.flush() -def merge_into(dest: Any, new: Any) -> None: +def merged(a: Any, b: Any) -> Any: """ - Merges `new` into `dest` with the following rules: + Merges `a` and `b` together, returning the result. + + The merge is performed with the following rules: - dicts: values with the same key will be merged recursively - lists: `new` will be appended to `dest` - primitives: they will be checked for equality and inequality will result in a ValueError - It is an error for `dest` and `new` to be of different types. - """ - if isinstance(dest, dict) and isinstance(new, dict): - for k, v in new.items(): - if k in dest: - merge_into(dest[k], v) - else: - dest[k] = v - elif isinstance(dest, list) and isinstance(new, list): - dest.extend(new) - elif type(dest) != type(new): - raise TypeError(f"Cannot merge {type(dest).__name__} and {type(new).__name__}") - elif dest != new: - raise ValueError(f"Cannot merge primitive values: {dest!r} != {new!r}") - -def merged(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]: + It is an error for `a` and `b` to be of different types. """ - Merges `b` into `a` and returns `a`. Here because we can't use `merge_into` - in a lamba conveniently. - """ - merge_into(a, b) + if isinstance(a, dict) and isinstance(b, dict): + result = {} + for key in set(a.keys()) | set(b.keys()): + if key in a and key in b: + result[key] = merged(a[key], b[key]) + elif key in a: + result[key] = deepcopy(a[key]) + else: + result[key] = deepcopy(b[key]) + + return result + elif isinstance(a, list) and isinstance(b, list): + return deepcopy(a) + deepcopy(b) + elif type(a) != type(b): + raise TypeError(f"Cannot merge {type(a).__name__} and {type(b).__name__}") + elif a != b: + raise ValueError(f"Cannot merge primitive values: {a!r} != {b!r}") + + if type(a) not in {str, int, float, bool, None.__class__}: + raise TypeError( + f"Cannot use `merged` on type {a} as it may not be safe (must either be an immutable primitive or must have special copy/merge logic)" + ) return a @@ -454,10 +460,10 @@ def instantiate_worker_template( Returns: worker configuration dictionary """ worker_config_dict = dataclasses.asdict(template) - stream_writers_dict = { - writer: worker_name for writer in template.stream_writers - } - worker_config_dict["shared_extra_conf"] = merged(template.shared_extra_conf(worker_name), stream_writers_dict) + stream_writers_dict = {writer: worker_name for writer in template.stream_writers} + worker_config_dict["shared_extra_conf"] = merged( + template.shared_extra_conf(worker_name), stream_writers_dict + ) worker_config_dict["endpoint_patterns"] = sorted(template.endpoint_patterns) worker_config_dict["listener_resources"] = sorted(template.listener_resources) return worker_config_dict @@ -786,7 +792,7 @@ def generate_worker_files( ) # Update the shared config with any options needed to enable this worker. - merge_into(shared_config, worker_config["shared_extra_conf"]) + shared_config = merged(shared_config, worker_config["shared_extra_conf"]) if using_unix_sockets: healthcheck_urls.append( |