diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index c2028a78d6..eac29357cd 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -51,6 +51,7 @@ import re
import subprocess
import sys
from collections import defaultdict
+from itertools import chain
from pathlib import Path
from typing import (
Any,
@@ -733,20 +734,14 @@ def generate_worker_files(
# program blocks.
worker_descriptors: List[Dict[str, Any]] = []
- # Upstreams for load-balancing purposes. This dict takes the form of the base worker
- # name to the ports of each worker. For example:
+    # Upstreams for load-balancing purposes. This dict maps each worker type to the
+    # ports of the workers running that type. For example:
# {
- # worker_base_name: {1234, 1235, ...}}
+    #   worker_type: {1234, 1235, ...}
# }
# and will be used to construct 'upstream' nginx directives.
nginx_upstreams: Dict[str, Set[int]] = {}
- # A map that will collect port data for load-balancing upstreams before being
- # reprocessed into nginx_locations. Unfortunately, we cannot just use
- # nginx_locations as there is a typing clash.
- # Format: {"endpoint": {1234, 1235, ...}}
- nginx_preprocessed_locations: Dict[str, Set[int]] = {}
-
# A map of: {"endpoint": "upstream"}, where "upstream" is a str representing what
# will be placed after the proxy_pass directive. The main benefit to representing
# this data as a dict over a str is that we can easily deduplicate endpoints
@@ -764,6 +759,14 @@ def generate_worker_files(
# which exists even if no workers do.
healthcheck_urls = ["http://localhost:8080/health"]
+ # Get the set of all worker types that we have configured
+ all_worker_types_in_use = set(chain(*requested_worker_types.values()))
+    # Map endpoint patterns to nginx upstreams (one upstream per worker type), but
+    # only for the worker types that are actually in use
+ for worker_type in all_worker_types_in_use:
+ for endpoint_pattern in WORKERS_CONFIG[worker_type]["endpoint_patterns"]:
+ nginx_locations[endpoint_pattern] = f"http://{worker_type}"
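+    # For illustration only (hypothetical pattern, not the real WORKERS_CONFIG
+    # contents): if the "federation_reader" worker type were in use and declared
+    # the endpoint pattern "^/_matrix/federation/", the loop above would record
+    #   nginx_locations["^/_matrix/federation/"] = "http://federation_reader"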
+
     # For each worker type specified by the user, create config values and write its
# yaml config file
for worker_name, worker_types_set in requested_worker_types.items():
@@ -812,12 +815,6 @@ def generate_worker_files(
# Enable the worker in supervisord
worker_descriptors.append(worker_config)
- # Add nginx location blocks for this worker's endpoints (if any are defined)
- for pattern in worker_config["endpoint_patterns"]:
- # Need more data to determine whether we need to load-balance this worker.
- # Collect all the port numbers for a given endpoint
- nginx_preprocessed_locations.setdefault(pattern, set()).add(worker_port)
-
# Write out the worker's logging config file
log_config_filepath = generate_worker_log_config(environ, worker_name, data_dir)
@@ -829,30 +826,11 @@ def generate_worker_files(
worker_log_config_filepath=log_config_filepath,
)
- worker_port += 1
-
- # Re process all nginx upstream data. Worker_descriptors contains all the port data,
- # cross-reference that with the worker_base_name in requested_worker_types.
- for pattern, port_set in nginx_preprocessed_locations.items():
- upstream_name: Set[str] = set()
- for worker in worker_descriptors:
- # Find the port we want
- if int(worker["port"]) in port_set:
- # Capture the name. We want the base name as they will be grouped
- # together.
- upstream_name.add(
- requested_worker_types[worker["name"]].get("worker_base_name")
- )
-
- # Join it all up nice and pretty with a double underscore
- upstream = "__".join(sorted(upstream_name))
- upstream_location = "http://" + upstream
-
- # Save the upstream location to it's associated pattern
- nginx_locations[pattern] = upstream_location
+        # Add this worker's port to the nginx upstream of each of its worker types
+ for worker_type in worker_types_set:
+ nginx_upstreams.setdefault(worker_type, set()).add(worker_port)
- # And save the port numbers for writing out below
- nginx_upstreams[upstream] = port_set
+ worker_port += 1
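+
+    # For illustration only (hypothetical worker names and ports): two workers that
+    # both include the "federation_reader" worker type, listening on ports 18009
+    # and 18010, would leave nginx_upstreams == {"federation_reader": {18009, 18010}}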
# Build the nginx location config blocks
nginx_location_config = ""
@@ -868,7 +846,7 @@ def generate_worker_files(
for upstream_worker_base_name, upstream_worker_ports in nginx_upstreams.items():
body = ""
for port in upstream_worker_ports:
- body += " server localhost:%d;\n" % (port,)
+ body += f" server localhost:{port};\n"
# Add to the list of configured upstreams
nginx_upstream_config += NGINX_UPSTREAM_CONFIG_BLOCK.format(