summary refs log tree commit diff
path: root/synapse/config
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config')
-rw-r--r--synapse/config/_util.py28
-rw-r--r--synapse/config/workers.py52
2 files changed, 70 insertions, 10 deletions
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index d3a4b484ab..dfc5d12210 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterable
+from typing import Any, Dict, Iterable, Type, TypeVar
 
 import jsonschema
+from pydantic import BaseModel, ValidationError, parse_obj_as
 
 from synapse.config._base import ConfigError
 from synapse.types import JsonDict
@@ -64,3 +65,28 @@ def json_error_to_config_error(
         else:
             path.append(str(p))
     return ConfigError(e.message, path)
+
+
+Model = TypeVar("Model", bound=BaseModel)
+
+
+def parse_and_validate_mapping(
+    config: Any,
+    model_type: Type[Model],
+) -> Dict[str, Model]:
+    """Parse `config` as a mapping from strings to a given `Model` type.
+    Args:
+        config: The configuration data to check
+        model_type: The BaseModel to validate and parse against.
+    Returns:
+        Fully validated and parsed Dict[str, Model].
+    Raises:
+        ConfigError, if given improper input.
+    """
+    try:
+        # type-ignore: mypy doesn't like constructing `Dict[str, model_type]` because
+        # `model_type` is a runtime variable. Pydantic is fine with this.
+        instances = parse_obj_as(Dict[str, model_type], config)  # type: ignore[valid-type]
+    except ValidationError as e:
+        raise ConfigError(str(e)) from e
+    return instances
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 1dfbe27e89..95b4047f1d 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -18,6 +18,7 @@ import logging
 from typing import Any, Dict, List, Union
 
 import attr
+from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
 
 from synapse.config._base import (
     Config,
@@ -25,6 +26,7 @@ from synapse.config._base import (
     RoutableShardedWorkerHandlingConfig,
     ShardedWorkerHandlingConfig,
 )
+from synapse.config._util import parse_and_validate_mapping
 from synapse.config.server import (
     DIRECT_TCP_ERROR,
     TCPListenerConfig,
@@ -50,13 +52,43 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
     return obj
 
 
-@attr.s(auto_attribs=True)
-class InstanceLocationConfig:
+class ConfigModel(BaseModel):
+    """A custom version of Pydantic's BaseModel which
+
+     - ignores unknown fields and
+     - does not allow fields to be overwritten after construction,
+
+    but otherwise uses Pydantic's default behaviour.
+
+    For now, ignore unknown fields. In the future, we could change this so that unknown
+    config values cause a ValidationError, provided the error messages are meaningful to
+    server operators.
+
+    Subclassing in this way is recommended by
+    https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+    """
+
+    class Config:
+        # By default, ignore fields that we don't recognise.
+        extra = Extra.ignore
+        # By default, don't allow fields to be reassigned after parsing.
+        allow_mutation = False
+
+
+class InstanceLocationConfig(ConfigModel):
     """The host and port to talk to an instance via HTTP replication."""
 
-    host: str
-    port: int
-    tls: bool = False
+    host: StrictStr
+    port: StrictInt
+    tls: StrictBool = False
+
+    def scheme(self) -> str:
+        """Hardcode a retrievable scheme based on self.tls"""
+        return "https" if self.tls else "http"
+
+    def netloc(self) -> str:
+        """Nicely format the network location data"""
+        return f"{self.host}:{self.port}"
 
 
 @attr.s
@@ -183,10 +215,12 @@ class WorkerConfig(Config):
         )
 
         # A map from instance name to host/port of their HTTP replication endpoint.
-        instance_map = config.get("instance_map") or {}
-        self.instance_map = {
-            name: InstanceLocationConfig(**c) for name, c in instance_map.items()
-        }
+        self.instance_map: Dict[
+            str, InstanceLocationConfig
+        ] = parse_and_validate_mapping(
+            config.get("instance_map", {}),
+            InstanceLocationConfig,
+        )
 
         # Map from type of streams to source, c.f. WriterLocations.
         writers = config.get("stream_writers") or {}