summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xdocker/configure_workers_and_start.py60
1 files changed, 33 insertions, 27 deletions
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index b2a03f075a..80f0a2e542 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -55,6 +55,7 @@ import subprocess
 import sys
 from argparse import ArgumentParser
 from collections import defaultdict
+from copy import deepcopy
 from dataclasses import dataclass, field
 from itertools import chain
 from pathlib import Path
@@ -321,37 +322,42 @@ def flush_buffers() -> None:
     sys.stderr.flush()
 
 
-def merge_into(dest: Any, new: Any) -> None:
+def merged(a: Any, b: Any) -> Any:
     """
-    Merges `new` into `dest` with the following rules:
+    Merges `a` and `b` together, returning the result.
+
+    The merge is performed with the following rules:
 
     - dicts: values with the same key will be merged recursively
     - lists: `new` will be appended to `dest`
     - primitives: they will be checked for equality and inequality will result
         in a ValueError
 
-    It is an error for `dest` and `new` to be of different types.
-    """
-    if isinstance(dest, dict) and isinstance(new, dict):
-        for k, v in new.items():
-            if k in dest:
-                merge_into(dest[k], v)
-            else:
-                dest[k] = v
-    elif isinstance(dest, list) and isinstance(new, list):
-        dest.extend(new)
-    elif type(dest) != type(new):
-        raise TypeError(f"Cannot merge {type(dest).__name__} and {type(new).__name__}")
-    elif dest != new:
-        raise ValueError(f"Cannot merge primitive values: {dest!r} != {new!r}")
-
 
-def merged(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
+    It is an error for `a` and `b` to be of different types.
     """
-    Merges `b` into `a` and returns `a`. Here because we can't use `merge_into`
-    in a lamba conveniently.
-    """
-    merge_into(a, b)
+    if isinstance(a, dict) and isinstance(b, dict):
+        result = {}
+        for key in set(a.keys()) | set(b.keys()):
+            if key in a and key in b:
+                result[key] = merged(a[key], b[key])
+            elif key in a:
+                result[key] = deepcopy(a[key])
+            else:
+                result[key] = deepcopy(b[key])
+
+        return result
+    elif isinstance(a, list) and isinstance(b, list):
+        return deepcopy(a) + deepcopy(b)
+    elif type(a) != type(b):
+        raise TypeError(f"Cannot merge {type(a).__name__} and {type(b).__name__}")
+    elif a != b:
+        raise ValueError(f"Cannot merge primitive values: {a!r} != {b!r}")
+
+    if type(a) not in {str, int, float, bool, None.__class__}:
+        raise TypeError(
+            f"Cannot use `merged` on type {a} as it may not be safe (must either be an immutable primitive or must have special copy/merge logic)"
+        )
     return a
 
 
@@ -454,10 +460,10 @@ def instantiate_worker_template(
     Returns: worker configuration dictionary
     """
     worker_config_dict = dataclasses.asdict(template)
-    stream_writers_dict = {
-        writer: worker_name for writer in template.stream_writers
-    }
-    worker_config_dict["shared_extra_conf"] = merged(template.shared_extra_conf(worker_name), stream_writers_dict)
+    stream_writers_dict = {writer: worker_name for writer in template.stream_writers}
+    worker_config_dict["shared_extra_conf"] = merged(
+        template.shared_extra_conf(worker_name), stream_writers_dict
+    )
     worker_config_dict["endpoint_patterns"] = sorted(template.endpoint_patterns)
     worker_config_dict["listener_resources"] = sorted(template.listener_resources)
     return worker_config_dict
@@ -786,7 +792,7 @@ def generate_worker_files(
         )
 
         # Update the shared config with any options needed to enable this worker.
-        merge_into(shared_config, worker_config["shared_extra_conf"])
+        shared_config = merged(shared_config, worker_config["shared_extra_conf"])
 
         if using_unix_sockets:
             healthcheck_urls.append(