summary refs log tree commit diff
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2023-04-05 12:37:38 +0100
committerOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2023-04-05 12:37:38 +0100
commitd7a1948a47683d6e01d61bbe295029357b931e82 (patch)
tree4799905078d845b4e79a17a974b04fcf375f4c9f
parentDelete server-side backup keys when deactivating an account. (#15181) (diff)
downloadsynapse-rei/worker_endpoint_factory.tar.xz
Initial crack at defining a worker agent endpoint factory github/rei/worker_endpoint_factory rei/worker_endpoint_factory
TODO it's not being configured in the HTTP Client yet
-rw-r--r--synapse/config/workers.py30
-rw-r--r--synapse/replication/http/_base.py23
-rw-r--r--synapse/replication/http/endpoint_factory.py67
-rw-r--r--tests/replication/_base.py2
4 files changed, 93 insertions, 29 deletions
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 1dfbe27e89..2c96f4bda4 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -18,6 +18,7 @@ import logging
 from typing import Any, Dict, List, Union
 
 import attr
+from pydantic import BaseModel, StrictBool, StrictInt, StrictStr, parse_obj_as
 
 from synapse.config._base import (
     Config,
@@ -50,13 +51,23 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
     return obj
 
 
-@attr.s(auto_attribs=True)
-class InstanceLocationConfig:
+class UnixSocketInstanceLocationConfig(BaseModel):
+    """The path to talk to an instance via HTTP replication over Unix socket."""
+
+    socket_path: StrictStr
+
+
+class TcpInstanceLocationConfig(BaseModel):
     """The host and port to talk to an instance via HTTP replication."""
 
-    host: str
-    port: int
-    tls: bool = False
+    host: StrictStr
+    port: StrictInt
+    tls: StrictBool = False
+
+
+InstanceLocationConfig = Union[
+    UnixSocketInstanceLocationConfig, TcpInstanceLocationConfig
+]
 
 
 @attr.s
@@ -182,11 +193,10 @@ class WorkerConfig(Config):
             federation_sender_instances
         )
 
-        # A map from instance name to host/port of their HTTP replication endpoint.
-        instance_map = config.get("instance_map") or {}
-        self.instance_map = {
-            name: InstanceLocationConfig(**c) for name, c in instance_map.items()
-        }
+        # A map from instance name to connection details for their HTTP replication endpoint.
+        self.instance_map = parse_obj_as(
+            Dict[str, InstanceLocationConfig], config.get("instance_map") or {}
+        )
 
         # Map from type of streams to source, c.f. WriterLocations.
         writers = config.get("stream_writers") or {}
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 8c2c54c07a..354a907cec 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -198,9 +198,6 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         local_instance_name = hs.get_instance_name()
 
         # The value of these option should match the replication listener settings
-        master_host = hs.config.worker.worker_replication_host
-        master_port = hs.config.worker.worker_replication_http_port
-        master_tls = hs.config.worker.worker_replication_http_tls
 
         instance_map = hs.config.worker.instance_map
 
@@ -221,15 +218,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             with outgoing_gauge.track_inprogress():
                 if instance_name == local_instance_name:
                     raise Exception("Trying to send HTTP request to self")
-                if instance_name == "master":
-                    host = master_host
-                    port = master_port
-                    tls = master_tls
-                elif instance_name in instance_map:
-                    host = instance_map[instance_name].host
-                    port = instance_map[instance_name].port
-                    tls = instance_map[instance_name].tls
-                else:
+                if instance_name not in instance_map:
                     raise Exception(
                         "Instance %r not in 'instance_map' config" % (instance_name,)
                     )
@@ -279,14 +268,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
                 # Here the protocol is hard coded to be http by default or https in case the replication
                 # port is set to have tls true.
+                tls = False  # TODO
                 scheme = "https" if tls else "http"
-                uri = "%s://%s:%s/_synapse/replication/%s/%s" % (
-                    scheme,
-                    host,
-                    port,
-                    cls.NAME,
-                    "/".join(url_args),
-                )
+                joined_args = "/".join(url_args)
+                uri = f"{scheme}://{instance_name}/_synapse/replication/{cls.NAME}/{joined_args}"
 
                 headers: Dict[bytes, List[bytes]] = {}
                 # Add an authorization header, if configured.
diff --git a/synapse/replication/http/endpoint_factory.py b/synapse/replication/http/endpoint_factory.py
new file mode 100644
index 0000000000..87ce468cf5
--- /dev/null
+++ b/synapse/replication/http/endpoint_factory.py
@@ -0,0 +1,67 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+
+from zope.interface import implementer
+
+from twisted.internet.endpoints import UNIXClientEndpoint
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.web.client import URI
+from twisted.web.iweb import IAgentEndpointFactory
+
+from synapse.config.workers import (
+    InstanceLocationConfig,
+    TcpInstanceLocationConfig,
+    UnixSocketInstanceLocationConfig,
+)
+from synapse.types import ISynapseReactor
+
+
+@implementer(IAgentEndpointFactory)
+class WorkerEndpointFactory:
+    def __init__(
+        self,
+        reactor: ISynapseReactor,
+        configs: Dict[str, InstanceLocationConfig],
+        tcp_endpoint_factory: IAgentEndpointFactory,
+    ):
+        self.reactor = reactor
+        self.configs = configs
+        self.tcp_agent_factory = tcp_endpoint_factory
+
+    def endpointForURI(self, uri: URI) -> IStreamClientEndpoint:
+        worker_config = self.configs.get(uri.host)
+        if not worker_config:
+            raise ValueError(f"Don't know how to connect to worker: {uri.host}")
+
+        if isinstance(worker_config, TcpInstanceLocationConfig):
+            # TODO TLS support
+            rewritten_uri = URI(
+                scheme=uri.scheme,
+                # TODO I'd probably cache the encoded netloc and host in the TCP Config?
+                netloc=f"{worker_config.host}:{worker_config.port}".encode("utf-8"),
+                host=worker_config.host.encode("utf-8"),
+                port=worker_config.port,
+                path=uri.path,
+                params=uri.params,
+                query=uri.query,
+                fragment=uri.fragment,
+            )
+            return self.tcp_agent_factory.endpointForURI(rewritten_uri)
+        elif isinstance(worker_config, UnixSocketInstanceLocationConfig):
+            return UNIXClientEndpoint(self.reactor, worker_config.socket_path)
+        else:
+            raise ValueError(
+                f"Unknown worker connection config {worker_config} for {uri.host}"
+            )
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 0f1a8a145f..0364c888ac 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import GenericWorkerServer
+from synapse.config.workers import TcpInstanceLocationConfig
 from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.client import ReplicationDataHandler
@@ -339,6 +340,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # `_handle_http_replication_attempt` like we do with the master HS.
         instance_name = worker_hs.get_instance_name()
         instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
+        assert isinstance(instance_loc, TcpInstanceLocationConfig)
         if instance_loc:
             # Ensure the host is one that has a fake DNS entry.
             if instance_loc.host not in self.reactor.lookups: