diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/config/workers.py | 30 | ||||
-rw-r--r-- | synapse/replication/http/_base.py | 23 | ||||
-rw-r--r-- | synapse/replication/http/endpoint_factory.py | 67 |
3 files changed, 91 insertions, 29 deletions
diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 1dfbe27e89..2c96f4bda4 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -18,6 +18,7 @@ import logging from typing import Any, Dict, List, Union import attr +from pydantic import BaseModel, StrictBool, StrictInt, StrictStr, parse_obj_as from synapse.config._base import ( Config, @@ -50,13 +51,23 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: return obj -@attr.s(auto_attribs=True) -class InstanceLocationConfig: +class UnixSocketInstanceLocationConfig(BaseModel): + """The path to talk to an instance via HTTP replication over Unix socket.""" + + socket_path: StrictStr + + +class TcpInstanceLocationConfig(BaseModel): """The host and port to talk to an instance via HTTP replication.""" - host: str - port: int - tls: bool = False + host: StrictStr + port: StrictInt + tls: StrictBool = False + + +InstanceLocationConfig = Union[ + UnixSocketInstanceLocationConfig, TcpInstanceLocationConfig +] @attr.s @@ -182,11 +193,10 @@ class WorkerConfig(Config): federation_sender_instances ) - # A map from instance name to host/port of their HTTP replication endpoint. - instance_map = config.get("instance_map") or {} - self.instance_map = { - name: InstanceLocationConfig(**c) for name, c in instance_map.items() - } + # A map from instance name to connection details for their HTTP replication endpoint. + self.instance_map = parse_obj_as( + Dict[str, InstanceLocationConfig], config.get("instance_map") or {} + ) # Map from type of streams to source, c.f. WriterLocations. writers = config.get("stream_writers") or {} diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 8c2c54c07a..354a907cec 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -198,9 +198,6 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): local_instance_name = hs.get_instance_name() # The value of these option should match the replication listener settings - master_host = hs.config.worker.worker_replication_host - master_port = hs.config.worker.worker_replication_http_port - master_tls = hs.config.worker.worker_replication_http_tls instance_map = hs.config.worker.instance_map @@ -221,15 +218,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): with outgoing_gauge.track_inprogress(): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") - if instance_name == "master": - host = master_host - port = master_port - tls = master_tls - elif instance_name in instance_map: - host = instance_map[instance_name].host - port = instance_map[instance_name].port - tls = instance_map[instance_name].tls - else: + if instance_name not in instance_map: raise Exception( "Instance %r not in 'instance_map' config" % (instance_name,) ) @@ -279,14 +268,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # Here the protocol is hard coded to be http by default or https in case the replication # port is set to have tls true. + tls = False # TODO scheme = "https" if tls else "http" - uri = "%s://%s:%s/_synapse/replication/%s/%s" % ( - scheme, - host, - port, - cls.NAME, - "/".join(url_args), - ) + joined_args = "/".join(url_args) + uri = f"{scheme}://{instance_name}/_synapse/replication/{cls.NAME}/{joined_args}" headers: Dict[bytes, List[bytes]] = {} # Add an authorization header, if configured. diff --git a/synapse/replication/http/endpoint_factory.py b/synapse/replication/http/endpoint_factory.py new file mode 100644 index 0000000000..87ce468cf5 --- /dev/null +++ b/synapse/replication/http/endpoint_factory.py @@ -0,0 +1,67 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +from zope.interface import implementer + +from twisted.internet.endpoints import UNIXClientEndpoint +from twisted.internet.interfaces import IStreamClientEndpoint +from twisted.web.client import URI +from twisted.web.iweb import IAgentEndpointFactory + +from synapse.config.workers import ( + InstanceLocationConfig, + TcpInstanceLocationConfig, + UnixSocketInstanceLocationConfig, +) +from synapse.types import ISynapseReactor + + +@implementer(IAgentEndpointFactory) +class WorkerEndpointFactory: + def __init__( + self, + reactor: ISynapseReactor, + configs: Dict[str, InstanceLocationConfig], + tcp_endpoint_factory: IAgentEndpointFactory, + ): + self.reactor = reactor + self.configs = configs + self.tcp_agent_factory = tcp_endpoint_factory + + def endpointForURI(self, uri: URI) -> IStreamClientEndpoint: + worker_config = self.configs.get(uri.host) + if not worker_config: + raise ValueError(f"Don't know how to connect to worker: {uri.host}") + + if isinstance(worker_config, TcpInstanceLocationConfig): + # TODO TLS support + rewritten_uri = URI( + scheme=uri.scheme, + # TODO I'd probably cache the encoded netloc and host in the TCP Config? + netloc=f"{worker_config.host}:{worker_config.port}".encode("utf-8"), + host=worker_config.host.encode("utf-8"), + port=worker_config.port, + path=uri.path, + params=uri.params, + query=uri.query, + fragment=uri.fragment, + ) + return self.tcp_agent_factory.endpointForURI(rewritten_uri) + elif isinstance(worker_config, UnixSocketInstanceLocationConfig): + return UNIXClientEndpoint(self.reactor, worker_config.socket_path) + else: + raise ValueError( + f"Unknown worker connection config {worker_config} for {uri.host}" + ) |