diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index d3a4b484ab..dfc5d12210 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable
+from typing import Any, Dict, Iterable, Type, TypeVar
import jsonschema
+from pydantic import BaseModel, ValidationError, parse_obj_as
from synapse.config._base import ConfigError
from synapse.types import JsonDict
@@ -64,3 +65,28 @@ def json_error_to_config_error(
else:
path.append(str(p))
return ConfigError(e.message, path)
+
+
+Model = TypeVar("Model", bound=BaseModel)  # stands for any BaseModel subclass
+
+
+def parse_and_validate_mapping(
+    config: Any,
+    model_type: Type[Model],
+) -> Dict[str, Model]:
+    """Validate `config` as a mapping from strings to `model_type` instances.
+    Args:
+        config: The raw configuration data to check; it may be of any type.
+        model_type: The BaseModel subclass each value is parsed and validated as.
+    Returns:
+        The parsed mapping, as a Dict[str, Model].
+    Raises:
+        ConfigError: if Pydantic raises a ValidationError for `config`. The message
+            is the stringified ValidationError, which is chained as the cause.
+    """
+    try:
+        # type-ignore: `model_type` is only known at runtime, so mypy refuses to
+        # accept `Dict[str, model_type]` as a type. Pydantic handles it fine.
+        return parse_obj_as(Dict[str, model_type], config)  # type: ignore[valid-type]
+    except ValidationError as exc:
+        raise ConfigError(str(exc)) from exc
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 1dfbe27e89..95b4047f1d 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -18,6 +18,7 @@ import logging
from typing import Any, Dict, List, Union
import attr
+from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
from synapse.config._base import (
Config,
@@ -25,6 +26,7 @@ from synapse.config._base import (
RoutableShardedWorkerHandlingConfig,
ShardedWorkerHandlingConfig,
)
+from synapse.config._util import parse_and_validate_mapping
from synapse.config.server import (
DIRECT_TCP_ERROR,
TCPListenerConfig,
@@ -50,13 +52,43 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
return obj
-@attr.s(auto_attribs=True)
-class InstanceLocationConfig:
+class ConfigModel(BaseModel):
+    """Base class for config models: Pydantic's BaseModel with two tweaks.
+
+    - Fields that are not declared on the model are ignored.
+    - Fields cannot be reassigned once the model has been constructed.
+
+    All other behaviour is Pydantic's default.
+
+    Ignoring unknown fields is the choice for now. A later change could turn
+    unknown config values into a ValidationError, provided the resulting error
+    messages are meaningful to server operators.
+
+    Pydantic's documentation recommends subclassing to share settings like this:
+    https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+    """
+
+    class Config:
+        # Reject attribute assignment on instances after they have been parsed.
+        allow_mutation = False
+        # Silently skip any input fields that the model does not declare.
+        extra = Extra.ignore
+
+
+class InstanceLocationConfig(ConfigModel):
     """The host and port to talk to an instance via HTTP replication."""
-    host: str
-    port: int
-    tls: bool = False
+    host: StrictStr
+    port: StrictInt
+    tls: StrictBool = False  # plain HTTP unless TLS is explicitly enabled
+
+    def scheme(self) -> str:
+        """Return "https" when `tls` is set and "http" otherwise."""
+        return "http" if not self.tls else "https"
+
+    def netloc(self) -> str:
+        """Return the network location of the instance, formatted as `host:port`."""
+        return f"{self.host}:{self.port}"
@attr.s
@@ -183,10 +215,12 @@ class WorkerConfig(Config):
)
# A map from instance name to host/port of their HTTP replication endpoint.
- instance_map = config.get("instance_map") or {}
- self.instance_map = {
- name: InstanceLocationConfig(**c) for name, c in instance_map.items()
- }
+        # `or {}` also covers a key that is present but null, unlike a .get() default.
+        self.instance_map: Dict[str, InstanceLocationConfig] = (
+            parse_and_validate_mapping(
+                config.get("instance_map") or {}, InstanceLocationConfig
+            )
+        )
# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("stream_writers") or {}
|