summary refs log tree commit diff
path: root/synapse/config
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config')
-rw-r--r--synapse/config/server.py235
-rw-r--r--synapse/config/workers.py24
2 files changed, 157 insertions, 102 deletions
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 73226e63d5..8204664883 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
 import os.path
 import re
 from textwrap import indent
-from typing import Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
 import attr
 import yaml
@@ -57,6 +57,64 @@ on how to configure the new listener.
 --------------------------------------------------------------------------------"""
 
 
+KNOWN_LISTENER_TYPES = {
+    "http",
+    "metrics",
+    "manhole",
+    "replication",
+}
+
+KNOWN_RESOURCES = {
+    "client",
+    "consent",
+    "federation",
+    "keys",
+    "media",
+    "metrics",
+    "openid",
+    "replication",
+    "static",
+    "webclient",
+}
+
+
+@attr.s(frozen=True)
+class HttpResourceConfig:
+    names = attr.ib(
+        type=List[str],
+        factory=list,
+        validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)),  # type: ignore
+    )
+    compress = attr.ib(
+        type=bool,
+        default=False,
+        validator=attr.validators.optional(attr.validators.instance_of(bool)),  # type: ignore[arg-type]
+    )
+
+
+@attr.s(frozen=True)
+class HttpListenerConfig:
+    """Object describing the http-specific parts of the config of a listener"""
+
+    x_forwarded = attr.ib(type=bool, default=False)
+    resources = attr.ib(type=List[HttpResourceConfig], factory=list)
+    additional_resources = attr.ib(type=Dict[str, dict], factory=dict)
+    tag = attr.ib(type=str, default=None)
+
+
+@attr.s(frozen=True)
+class ListenerConfig:
+    """Object describing the configuration of a single listener."""
+
+    port = attr.ib(type=int, validator=attr.validators.instance_of(int))
+    bind_addresses = attr.ib(type=List[str])
+    type = attr.ib(type=str, validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
+    tls = attr.ib(type=bool, default=False)
+
+    # http_options is only populated if type=http
+    http_options = attr.ib(type=Optional[HttpListenerConfig], default=None)
+
+
 class ServerConfig(Config):
     section = "server"
 
@@ -379,38 +437,21 @@ class ServerConfig(Config):
                 }
             ]
 
-        self.listeners = []  # type: List[dict]
-        for listener in config.get("listeners", []):
-            if not isinstance(listener.get("port", None), int):
-                raise ConfigError(
-                    "Listener configuration is lacking a valid 'port' option"
-                )
+        self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]
 
-            if listener.setdefault("tls", False):
-                # no_tls is not really supported any more, but let's grandfather it in
-                # here.
-                if config.get("no_tls", False):
+        # no_tls is not really supported any more, but let's grandfather it in
+        # here.
+        if config.get("no_tls", False):
+            l2 = []
+            for listener in self.listeners:
+                if listener.tls:
                     logger.info(
-                        "Ignoring TLS-enabled listener on port %i due to no_tls"
+                        "Ignoring TLS-enabled listener on port %i due to no_tls",
+                        listener.port,
                     )
-                    continue
-
-            bind_address = listener.pop("bind_address", None)
-            bind_addresses = listener.setdefault("bind_addresses", [])
-
-            # if bind_address was specified, add it to the list of addresses
-            if bind_address:
-                bind_addresses.append(bind_address)
-
-            # if we still have an empty list of addresses, use the default list
-            if not bind_addresses:
-                if listener["type"] == "metrics":
-                    # the metrics listener doesn't support IPv6
-                    bind_addresses.append("0.0.0.0")
                 else:
-                    bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
-
-            self.listeners.append(listener)
+                    l2.append(listener)
+            self.listeners = l2
 
         if not self.web_client_location:
             _warn_if_webclient_configured(self.listeners)
@@ -446,43 +487,41 @@ class ServerConfig(Config):
             bind_host = config.get("bind_host", "")
             gzip_responses = config.get("gzip_responses", True)
 
+            http_options = HttpListenerConfig(
+                resources=[
+                    HttpResourceConfig(names=["client"], compress=gzip_responses),
+                    HttpResourceConfig(names=["federation"]),
+                ],
+            )
+
             self.listeners.append(
-                {
-                    "port": bind_port,
-                    "bind_addresses": [bind_host],
-                    "tls": True,
-                    "type": "http",
-                    "resources": [
-                        {"names": ["client"], "compress": gzip_responses},
-                        {"names": ["federation"], "compress": False},
-                    ],
-                }
+                ListenerConfig(
+                    port=bind_port,
+                    bind_addresses=[bind_host],
+                    tls=True,
+                    type="http",
+                    http_options=http_options,
+                )
             )
 
             unsecure_port = config.get("unsecure_port", bind_port - 400)
             if unsecure_port:
                 self.listeners.append(
-                    {
-                        "port": unsecure_port,
-                        "bind_addresses": [bind_host],
-                        "tls": False,
-                        "type": "http",
-                        "resources": [
-                            {"names": ["client"], "compress": gzip_responses},
-                            {"names": ["federation"], "compress": False},
-                        ],
-                    }
+                    ListenerConfig(
+                        port=unsecure_port,
+                        bind_addresses=[bind_host],
+                        tls=False,
+                        type="http",
+                        http_options=http_options,
+                    )
                 )
 
         manhole = config.get("manhole")
         if manhole:
             self.listeners.append(
-                {
-                    "port": manhole,
-                    "bind_addresses": ["127.0.0.1"],
-                    "type": "manhole",
-                    "tls": False,
-                }
+                ListenerConfig(
+                    port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+                )
             )
 
         metrics_port = config.get("metrics_port")
@@ -490,13 +529,14 @@ class ServerConfig(Config):
             logger.warning(METRICS_PORT_WARNING)
 
             self.listeners.append(
-                {
-                    "port": metrics_port,
-                    "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
-                    "tls": False,
-                    "type": "http",
-                    "resources": [{"names": ["metrics"], "compress": False}],
-                }
+                ListenerConfig(
+                    port=metrics_port,
+                    bind_addresses=[config.get("metrics_bind_host", "127.0.0.1")],
+                    type="http",
+                    http_options=HttpListenerConfig(
+                        resources=[HttpResourceConfig(names=["metrics"])]
+                    ),
+                )
             )
 
         _check_resource_config(self.listeners)
@@ -522,7 +562,7 @@ class ServerConfig(Config):
         )
 
     def has_tls_listener(self) -> bool:
-        return any(listener["tls"] for listener in self.listeners)
+        return any(listener.tls for listener in self.listeners)
 
     def generate_config_section(
         self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
@@ -1081,6 +1121,44 @@ def read_gc_thresholds(thresholds):
         )
 
 
+def parse_listener_def(listener: Any) -> ListenerConfig:
+    """parse a listener config from the config file"""
+    listener_type = listener["type"]
+
+    port = listener.get("port")
+    if not isinstance(port, int):
+        raise ConfigError("Listener configuration is lacking a valid 'port' option")
+
+    tls = listener.get("tls", False)
+
+    bind_addresses = listener.get("bind_addresses", [])
+    bind_address = listener.get("bind_address")
+    # if bind_address was specified, add it to the list of addresses
+    if bind_address:
+        bind_addresses.append(bind_address)
+
+    # if we still have an empty list of addresses, use the default list
+    if not bind_addresses:
+        if listener_type == "metrics":
+            # the metrics listener doesn't support IPv6
+            bind_addresses.append("0.0.0.0")
+        else:
+            bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
+
+    http_config = None
+    if listener_type == "http":
+        http_config = HttpListenerConfig(
+            x_forwarded=listener.get("x_forwarded", False),
+            resources=[
+                HttpResourceConfig(**res) for res in listener.get("resources", [])
+            ],
+            additional_resources=listener.get("additional_resources", {}),
+            tag=listener.get("tag"),
+        )
+
+    return ListenerConfig(port, bind_addresses, listener_type, tls, http_config)
+
+
 NO_MORE_WEB_CLIENT_WARNING = """
 Synapse no longer includes a web client. To enable a web client, configure
 web_client_location. To remove this warning, remove 'webclient' from the 'listeners'
@@ -1088,40 +1166,27 @@ configuration.
 """
 
 
-def _warn_if_webclient_configured(listeners):
+def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
     for listener in listeners:
-        for res in listener.get("resources", []):
-            for name in res.get("names", []):
+        if not listener.http_options:
+            continue
+        for res in listener.http_options.resources:
+            for name in res.names:
                 if name == "webclient":
                     logger.warning(NO_MORE_WEB_CLIENT_WARNING)
                     return
 
 
-KNOWN_RESOURCES = (
-    "client",
-    "consent",
-    "federation",
-    "keys",
-    "media",
-    "metrics",
-    "openid",
-    "replication",
-    "static",
-    "webclient",
-)
-
-
-def _check_resource_config(listeners):
+def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
     resource_names = {
         res_name
         for listener in listeners
-        for res in listener.get("resources", [])
-        for res_name in res.get("names", [])
+        if listener.http_options
+        for res in listener.http_options.resources
+        for res_name in res.names
     }
 
     for resource in resource_names:
-        if resource not in KNOWN_RESOURCES:
-            raise ConfigError("Unknown listener resource '%s'" % (resource,))
         if resource == "consent":
             try:
                 check_requirements("resources.consent")
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index ed06b91a54..dbc661630c 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -16,6 +16,7 @@
 import attr
 
 from ._base import Config, ConfigError
+from .server import ListenerConfig, parse_listener_def
 
 
 @attr.s
@@ -52,7 +53,9 @@ class WorkerConfig(Config):
         if self.worker_app == "synapse.app.homeserver":
             self.worker_app = None
 
-        self.worker_listeners = config.get("worker_listeners", [])
+        self.worker_listeners = [
+            parse_listener_def(x) for x in config.get("worker_listeners", [])
+        ]
         self.worker_daemonize = config.get("worker_daemonize")
         self.worker_pid_file = config.get("worker_pid_file")
         self.worker_log_config = config.get("worker_log_config")
@@ -75,24 +78,11 @@ class WorkerConfig(Config):
         manhole = config.get("worker_manhole")
         if manhole:
             self.worker_listeners.append(
-                {
-                    "port": manhole,
-                    "bind_addresses": ["127.0.0.1"],
-                    "type": "manhole",
-                    "tls": False,
-                }
+                ListenerConfig(
+                    port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+                )
             )
 
-        if self.worker_listeners:
-            for listener in self.worker_listeners:
-                bind_address = listener.pop("bind_address", None)
-                bind_addresses = listener.setdefault("bind_addresses", [])
-
-                if bind_address:
-                    bind_addresses.append(bind_address)
-                elif not bind_addresses:
-                    bind_addresses.append("")
-
         # A map from instance name to host/port of their HTTP replication endpoint.
         instance_map = config.get("instance_map") or {}
         self.instance_map = {