summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-06-16 12:44:07 +0100
committerGitHub <noreply@github.com>2020-06-16 12:44:07 +0100
commit03619324fc18632a2907ace4d3e73f3c4dd0b05e (patch)
treef3873d5b6b6007d2ae1d6462fef4efbe7caee816 /synapse
parentMerge branch 'master' into develop (diff)
downloadsynapse-03619324fc18632a2907ace4d3e73f3c4dd0b05e.tar.xz
Create a ListenerConfig object (#7681)
This ended up being a bit more invasive than I'd hoped for (not helped by
generic_worker duplicating some of the code from homeserver), but hopefully
it's an improvement.

The idea is that, rather than storing unstructured `dict`s in the config for
the listener configurations, we instead parse it into a structured
`ListenerConfig` object.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/_base.py8
-rw-r--r--synapse/app/generic_worker.py36
-rw-r--r--synapse/app/homeserver.py50
-rw-r--r--synapse/config/server.py235
-rw-r--r--synapse/config/workers.py24
-rw-r--r--synapse/http/site.py6
-rw-r--r--synapse/python_dependencies.py5
7 files changed, 216 insertions, 148 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index dedff81af3..373a80a4a7 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -20,6 +20,7 @@ import signal
 import socket
 import sys
 import traceback
+from typing import Iterable
 
 from daemonize import Daemonize
 from typing_extensions import NoReturn
@@ -29,6 +30,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory
 
 import synapse
 from synapse.app import check_bind_error
+from synapse.config.server import ListenerConfig
 from synapse.crypto import context_factory
 from synapse.logging.context import PreserveLoggingContext
 from synapse.util.async_helpers import Linearizer
@@ -234,7 +236,7 @@ def refresh_certificate(hs):
         logger.info("Context factories updated.")
 
 
-def start(hs, listeners=None):
+def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
     """
     Start a Synapse server or worker.
 
@@ -245,8 +247,8 @@ def start(hs, listeners=None):
     notify systemd.
 
     Args:
-        hs (synapse.server.HomeServer)
-        listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml)
+        hs: homeserver instance
+        listeners: Listener configuration ('listeners' in homeserver.yaml)
     """
     try:
         # Set up the SIGHUP machinery.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 53c488d211..27a3fc9ed6 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -37,6 +37,7 @@ from synapse.app import _base
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
+from synapse.config.server import ListenerConfig
 from synapse.federation import send_queue
 from synapse.federation.transport.server import TransportLayerServer
 from synapse.handlers.presence import (
@@ -514,13 +515,18 @@ class GenericWorkerSlavedStore(
 class GenericWorkerServer(HomeServer):
     DATASTORE_CLASS = GenericWorkerSlavedStore
 
-    def _listen_http(self, listener_config):
-        port = listener_config["port"]
-        bind_addresses = listener_config["bind_addresses"]
-        site_tag = listener_config.get("tag", port)
+    def _listen_http(self, listener_config: ListenerConfig):
+        port = listener_config.port
+        bind_addresses = listener_config.bind_addresses
+
+        assert listener_config.http_options is not None
+
+        site_tag = listener_config.http_options.tag
+        if site_tag is None:
+            site_tag = port
         resources = {}
-        for res in listener_config["resources"]:
-            for name in res["names"]:
+        for res in listener_config.http_options.resources:
+            for name in res.names:
                 if name == "metrics":
                     resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
                 elif name == "client":
@@ -590,7 +596,7 @@ class GenericWorkerServer(HomeServer):
                             " repository is disabled. Ignoring."
                         )
 
-                if name == "openid" and "federation" not in res["names"]:
+                if name == "openid" and "federation" not in res.names:
                     # Only load the openid resource separately if federation resource
                     # is not specified since federation resource includes openid
                     # resource.
@@ -625,19 +631,19 @@ class GenericWorkerServer(HomeServer):
 
         logger.info("Synapse worker now listening on port %d", port)
 
-    def start_listening(self, listeners):
+    def start_listening(self, listeners: Iterable[ListenerConfig]):
         for listener in listeners:
-            if listener["type"] == "http":
+            if listener.type == "http":
                 self._listen_http(listener)
-            elif listener["type"] == "manhole":
+            elif listener.type == "manhole":
                 _base.listen_tcp(
-                    listener["bind_addresses"],
-                    listener["port"],
+                    listener.bind_addresses,
+                    listener.port,
                     manhole(
                         username="matrix", password="rabbithole", globals={"hs": self}
                     ),
                 )
-            elif listener["type"] == "metrics":
+            elif listener.type == "metrics":
                 if not self.get_config().enable_metrics:
                     logger.warning(
                         (
@@ -646,9 +652,9 @@ class GenericWorkerServer(HomeServer):
                         )
                     )
                 else:
-                    _base.listen_metrics(listener["bind_addresses"], listener["port"])
+                    _base.listen_metrics(listener.bind_addresses, listener.port)
             else:
-                logger.warning("Unrecognized listener type: %s", listener["type"])
+                logger.warning("Unsupported listener type: %s", listener.type)
 
         self.get_tcp_replication().start_replication(self)
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 93bc45208e..299134d00f 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -23,6 +23,7 @@ import math
 import os
 import resource
 import sys
+from typing import Iterable
 
 from prometheus_client import Gauge
 
@@ -48,6 +49,7 @@ from synapse.app import _base
 from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import ListenerConfig
 from synapse.federation.transport.server import TransportLayerServer
 from synapse.http.additional_resource import AdditionalResource
 from synapse.http.server import (
@@ -87,24 +89,24 @@ def gz_wrap(r):
 class SynapseHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
 
-    def _listener_http(self, config, listener_config):
-        port = listener_config["port"]
-        bind_addresses = listener_config["bind_addresses"]
-        tls = listener_config.get("tls", False)
-        site_tag = listener_config.get("tag", port)
+    def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig):
+        port = listener_config.port
+        bind_addresses = listener_config.bind_addresses
+        tls = listener_config.tls
+        site_tag = listener_config.http_options.tag
+        if site_tag is None:
+            site_tag = port
 
         resources = {}
-        for res in listener_config["resources"]:
-            for name in res["names"]:
-                if name == "openid" and "federation" in res["names"]:
+        for res in listener_config.http_options.resources:
+            for name in res.names:
+                if name == "openid" and "federation" in res.names:
                     # Skip loading openid resource if federation is defined
                     # since federation resource will include openid
                     continue
-                resources.update(
-                    self._configure_named_resource(name, res.get("compress", False))
-                )
+                resources.update(self._configure_named_resource(name, res.compress))
 
-        additional_resources = listener_config.get("additional_resources", {})
+        additional_resources = listener_config.http_options.additional_resources
         logger.debug("Configuring additional resources: %r", additional_resources)
         module_api = ModuleApi(self, self.get_auth_handler())
         for path, resmodule in additional_resources.items():
@@ -276,7 +278,7 @@ class SynapseHomeServer(HomeServer):
 
         return resources
 
-    def start_listening(self, listeners):
+    def start_listening(self, listeners: Iterable[ListenerConfig]):
         config = self.get_config()
 
         if config.redis_enabled:
@@ -286,25 +288,25 @@ class SynapseHomeServer(HomeServer):
             self.get_tcp_replication().start_replication(self)
 
         for listener in listeners:
-            if listener["type"] == "http":
+            if listener.type == "http":
                 self._listening_services.extend(self._listener_http(config, listener))
-            elif listener["type"] == "manhole":
+            elif listener.type == "manhole":
                 listen_tcp(
-                    listener["bind_addresses"],
-                    listener["port"],
+                    listener.bind_addresses,
+                    listener.port,
                     manhole(
                         username="matrix", password="rabbithole", globals={"hs": self}
                     ),
                 )
-            elif listener["type"] == "replication":
+            elif listener.type == "replication":
                 services = listen_tcp(
-                    listener["bind_addresses"],
-                    listener["port"],
+                    listener.bind_addresses,
+                    listener.port,
                     ReplicationStreamProtocolFactory(self),
                 )
                 for s in services:
                     reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
-            elif listener["type"] == "metrics":
+            elif listener.type == "metrics":
                 if not self.get_config().enable_metrics:
                     logger.warning(
                         (
@@ -313,9 +315,11 @@ class SynapseHomeServer(HomeServer):
                         )
                     )
                 else:
-                    _base.listen_metrics(listener["bind_addresses"], listener["port"])
+                    _base.listen_metrics(listener.bind_addresses, listener.port)
             else:
-                logger.warning("Unrecognized listener type: %s", listener["type"])
+                # this shouldn't happen, as the listener type should have been checked
+                # during parsing
+                logger.warning("Unrecognized listener type: %s", listener.type)
 
 
 # Gauges to expose monthly active user control metrics
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 73226e63d5..8204664883 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
 import os.path
 import re
 from textwrap import indent
-from typing import Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
 import attr
 import yaml
@@ -57,6 +57,64 @@ on how to configure the new listener.
 --------------------------------------------------------------------------------"""
 
 
+KNOWN_LISTENER_TYPES = {
+    "http",
+    "metrics",
+    "manhole",
+    "replication",
+}
+
+KNOWN_RESOURCES = {
+    "client",
+    "consent",
+    "federation",
+    "keys",
+    "media",
+    "metrics",
+    "openid",
+    "replication",
+    "static",
+    "webclient",
+}
+
+
+@attr.s(frozen=True)
+class HttpResourceConfig:
+    names = attr.ib(
+        type=List[str],
+        factory=list,
+        validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)),  # type: ignore
+    )
+    compress = attr.ib(
+        type=bool,
+        default=False,
+        validator=attr.validators.optional(attr.validators.instance_of(bool)),  # type: ignore[arg-type]
+    )
+
+
+@attr.s(frozen=True)
+class HttpListenerConfig:
+    """Object describing the http-specific parts of the config of a listener"""
+
+    x_forwarded = attr.ib(type=bool, default=False)
+    resources = attr.ib(type=List[HttpResourceConfig], factory=list)
+    additional_resources = attr.ib(type=Dict[str, dict], factory=dict)
+    tag = attr.ib(type=str, default=None)
+
+
+@attr.s(frozen=True)
+class ListenerConfig:
+    """Object describing the configuration of a single listener."""
+
+    port = attr.ib(type=int, validator=attr.validators.instance_of(int))
+    bind_addresses = attr.ib(type=List[str])
+    type = attr.ib(type=str, validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
+    tls = attr.ib(type=bool, default=False)
+
+    # http_options is only populated if type=http
+    http_options = attr.ib(type=Optional[HttpListenerConfig], default=None)
+
+
 class ServerConfig(Config):
     section = "server"
 
@@ -379,38 +437,21 @@ class ServerConfig(Config):
                 }
             ]
 
-        self.listeners = []  # type: List[dict]
-        for listener in config.get("listeners", []):
-            if not isinstance(listener.get("port", None), int):
-                raise ConfigError(
-                    "Listener configuration is lacking a valid 'port' option"
-                )
+        self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]
 
-            if listener.setdefault("tls", False):
-                # no_tls is not really supported any more, but let's grandfather it in
-                # here.
-                if config.get("no_tls", False):
+        # no_tls is not really supported any more, but let's grandfather it in
+        # here.
+        if config.get("no_tls", False):
+            l2 = []
+            for listener in self.listeners:
+                if listener.tls:
                     logger.info(
-                        "Ignoring TLS-enabled listener on port %i due to no_tls"
+                        "Ignoring TLS-enabled listener on port %i due to no_tls",
+                        listener.port,
                     )
-                    continue
-
-            bind_address = listener.pop("bind_address", None)
-            bind_addresses = listener.setdefault("bind_addresses", [])
-
-            # if bind_address was specified, add it to the list of addresses
-            if bind_address:
-                bind_addresses.append(bind_address)
-
-            # if we still have an empty list of addresses, use the default list
-            if not bind_addresses:
-                if listener["type"] == "metrics":
-                    # the metrics listener doesn't support IPv6
-                    bind_addresses.append("0.0.0.0")
                 else:
-                    bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
-
-            self.listeners.append(listener)
+                    l2.append(listener)
+            self.listeners = l2
 
         if not self.web_client_location:
             _warn_if_webclient_configured(self.listeners)
@@ -446,43 +487,41 @@ class ServerConfig(Config):
             bind_host = config.get("bind_host", "")
             gzip_responses = config.get("gzip_responses", True)
 
+            http_options = HttpListenerConfig(
+                resources=[
+                    HttpResourceConfig(names=["client"], compress=gzip_responses),
+                    HttpResourceConfig(names=["federation"]),
+                ],
+            )
+
             self.listeners.append(
-                {
-                    "port": bind_port,
-                    "bind_addresses": [bind_host],
-                    "tls": True,
-                    "type": "http",
-                    "resources": [
-                        {"names": ["client"], "compress": gzip_responses},
-                        {"names": ["federation"], "compress": False},
-                    ],
-                }
+                ListenerConfig(
+                    port=bind_port,
+                    bind_addresses=[bind_host],
+                    tls=True,
+                    type="http",
+                    http_options=http_options,
+                )
             )
 
             unsecure_port = config.get("unsecure_port", bind_port - 400)
             if unsecure_port:
                 self.listeners.append(
-                    {
-                        "port": unsecure_port,
-                        "bind_addresses": [bind_host],
-                        "tls": False,
-                        "type": "http",
-                        "resources": [
-                            {"names": ["client"], "compress": gzip_responses},
-                            {"names": ["federation"], "compress": False},
-                        ],
-                    }
+                    ListenerConfig(
+                        port=unsecure_port,
+                        bind_addresses=[bind_host],
+                        tls=False,
+                        type="http",
+                        http_options=http_options,
+                    )
                 )
 
         manhole = config.get("manhole")
         if manhole:
             self.listeners.append(
-                {
-                    "port": manhole,
-                    "bind_addresses": ["127.0.0.1"],
-                    "type": "manhole",
-                    "tls": False,
-                }
+                ListenerConfig(
+                    port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+                )
             )
 
         metrics_port = config.get("metrics_port")
@@ -490,13 +529,14 @@ class ServerConfig(Config):
             logger.warning(METRICS_PORT_WARNING)
 
             self.listeners.append(
-                {
-                    "port": metrics_port,
-                    "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
-                    "tls": False,
-                    "type": "http",
-                    "resources": [{"names": ["metrics"], "compress": False}],
-                }
+                ListenerConfig(
+                    port=metrics_port,
+                    bind_addresses=[config.get("metrics_bind_host", "127.0.0.1")],
+                    type="http",
+                    http_options=HttpListenerConfig(
+                        resources=[HttpResourceConfig(names=["metrics"])]
+                    ),
+                )
             )
 
         _check_resource_config(self.listeners)
@@ -522,7 +562,7 @@ class ServerConfig(Config):
         )
 
     def has_tls_listener(self) -> bool:
-        return any(listener["tls"] for listener in self.listeners)
+        return any(listener.tls for listener in self.listeners)
 
     def generate_config_section(
         self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
@@ -1081,6 +1121,44 @@ def read_gc_thresholds(thresholds):
         )
 
 
+def parse_listener_def(listener: Any) -> ListenerConfig:
+    """parse a listener config from the config file"""
+    listener_type = listener["type"]
+
+    port = listener.get("port")
+    if not isinstance(port, int):
+        raise ConfigError("Listener configuration is lacking a valid 'port' option")
+
+    tls = listener.get("tls", False)
+
+    bind_addresses = listener.get("bind_addresses", [])
+    bind_address = listener.get("bind_address")
+    # if bind_address was specified, add it to the list of addresses
+    if bind_address:
+        bind_addresses.append(bind_address)
+
+    # if we still have an empty list of addresses, use the default list
+    if not bind_addresses:
+        if listener_type == "metrics":
+            # the metrics listener doesn't support IPv6
+            bind_addresses.append("0.0.0.0")
+        else:
+            bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
+
+    http_config = None
+    if listener_type == "http":
+        http_config = HttpListenerConfig(
+            x_forwarded=listener.get("x_forwarded", False),
+            resources=[
+                HttpResourceConfig(**res) for res in listener.get("resources", [])
+            ],
+            additional_resources=listener.get("additional_resources", {}),
+            tag=listener.get("tag"),
+        )
+
+    return ListenerConfig(port, bind_addresses, listener_type, tls, http_config)
+
+
 NO_MORE_WEB_CLIENT_WARNING = """
 Synapse no longer includes a web client. To enable a web client, configure
 web_client_location. To remove this warning, remove 'webclient' from the 'listeners'
@@ -1088,40 +1166,27 @@ configuration.
 """
 
 
-def _warn_if_webclient_configured(listeners):
+def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
     for listener in listeners:
-        for res in listener.get("resources", []):
-            for name in res.get("names", []):
+        if not listener.http_options:
+            continue
+        for res in listener.http_options.resources:
+            for name in res.names:
                 if name == "webclient":
                     logger.warning(NO_MORE_WEB_CLIENT_WARNING)
                     return
 
 
-KNOWN_RESOURCES = (
-    "client",
-    "consent",
-    "federation",
-    "keys",
-    "media",
-    "metrics",
-    "openid",
-    "replication",
-    "static",
-    "webclient",
-)
-
-
-def _check_resource_config(listeners):
+def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
     resource_names = {
         res_name
         for listener in listeners
-        for res in listener.get("resources", [])
-        for res_name in res.get("names", [])
+        if listener.http_options
+        for res in listener.http_options.resources
+        for res_name in res.names
     }
 
     for resource in resource_names:
-        if resource not in KNOWN_RESOURCES:
-            raise ConfigError("Unknown listener resource '%s'" % (resource,))
         if resource == "consent":
             try:
                 check_requirements("resources.consent")
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index ed06b91a54..dbc661630c 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -16,6 +16,7 @@
 import attr
 
 from ._base import Config, ConfigError
+from .server import ListenerConfig, parse_listener_def
 
 
 @attr.s
@@ -52,7 +53,9 @@ class WorkerConfig(Config):
         if self.worker_app == "synapse.app.homeserver":
             self.worker_app = None
 
-        self.worker_listeners = config.get("worker_listeners", [])
+        self.worker_listeners = [
+            parse_listener_def(x) for x in config.get("worker_listeners", [])
+        ]
         self.worker_daemonize = config.get("worker_daemonize")
         self.worker_pid_file = config.get("worker_pid_file")
         self.worker_log_config = config.get("worker_log_config")
@@ -75,24 +78,11 @@ class WorkerConfig(Config):
         manhole = config.get("worker_manhole")
         if manhole:
             self.worker_listeners.append(
-                {
-                    "port": manhole,
-                    "bind_addresses": ["127.0.0.1"],
-                    "type": "manhole",
-                    "tls": False,
-                }
+                ListenerConfig(
+                    port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+                )
             )
 
-        if self.worker_listeners:
-            for listener in self.worker_listeners:
-                bind_address = listener.pop("bind_address", None)
-                bind_addresses = listener.setdefault("bind_addresses", [])
-
-                if bind_address:
-                    bind_addresses.append(bind_address)
-                elif not bind_addresses:
-                    bind_addresses.append("")
-
         # A map from instance name to host/port of their HTTP replication endpoint.
         instance_map = config.get("instance_map") or {}
         self.instance_map = {
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 167293c46d..cbc37eac6e 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,6 +19,7 @@ from typing import Optional
 from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
 
+from synapse.config.server import ListenerConfig
 from synapse.http import redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
@@ -350,7 +351,7 @@ class SynapseSite(Site):
         self,
         logger_name,
         site_tag,
-        config,
+        config: ListenerConfig,
         resource,
         server_version_string,
         *args,
@@ -360,7 +361,8 @@ class SynapseSite(Site):
 
         self.site_tag = site_tag
 
-        proxied = config.get("x_forwarded", False)
+        assert config.http_options is not None
+        proxied = config.http_options.x_forwarded
         self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 8b4312e5a3..8ec1a619a2 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -68,9 +68,8 @@ REQUIREMENTS = [
     "phonenumbers>=8.2.0",
     "six>=1.10",
     "prometheus_client>=0.0.18,<0.8.0",
-    # we use attr.s(slots), which arrived in 16.0.0
-    # Twisted 18.7.0 requires attrs>=17.4.0
-    "attrs>=17.4.0",
+    # we use attr.validators.deep_iterable, which arrived in 19.1.0
+    "attrs>=19.1.0",
     "netaddr>=0.7.18",
     "Jinja2>=2.9",
     "bleach>=1.4.3",