summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15578.misc1
-rw-r--r--synapse/http/client.py1
-rw-r--r--synapse/http/replicationagent.py34
-rw-r--r--synapse/replication/http/_base.py18
4 files changed, 35 insertions, 19 deletions
diff --git a/changelog.d/15578.misc b/changelog.d/15578.misc
new file mode 100644
index 0000000000..a54422239b
--- /dev/null
+++ b/changelog.d/15578.misc
@@ -0,0 +1 @@
+Allow connecting to HTTP Replication Endpoints by using `worker_name` when constructing the request.
diff --git a/synapse/http/client.py b/synapse/http/client.py
index f1ab7a8bc9..09ea93e10d 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -835,6 +835,7 @@ class ReplicationClient(BaseHttpClient):
 
         self.agent: IAgent = ReplicationAgent(
             hs.get_reactor(),
+            hs.config.worker.instance_map,
             contextFactory=hs.get_http_client_context_factory(),
             pool=pool,
         )
diff --git a/synapse/http/replicationagent.py b/synapse/http/replicationagent.py
index 5ecd08be0f..800f21873d 100644
--- a/synapse/http/replicationagent.py
+++ b/synapse/http/replicationagent.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Optional
+from typing import Dict, Optional
 
 from zope.interface import implementer
 
@@ -32,6 +32,7 @@ from twisted.web.iweb import (
     IResponse,
 )
 
+from synapse.config.workers import InstanceLocationConfig
 from synapse.types import ISynapseReactor
 
 logger = logging.getLogger(__name__)
@@ -44,9 +45,11 @@ class ReplicationEndpointFactory:
     def __init__(
         self,
         reactor: ISynapseReactor,
+        instance_map: Dict[str, InstanceLocationConfig],
         context_factory: IPolicyForHTTPS,
     ) -> None:
         self.reactor = reactor
+        self.instance_map = instance_map
         self.context_factory = context_factory
 
     def endpointForURI(self, uri: URI) -> IStreamClientEndpoint:
@@ -58,15 +61,29 @@ class ReplicationEndpointFactory:
 
         Returns: The correct client endpoint object
         """
-        if uri.scheme in (b"http", b"https"):
-            endpoint = HostnameEndpoint(self.reactor, uri.host, uri.port)
-            if uri.scheme == b"https":
+        # The given URI has a special scheme and includes the worker name. The
+        # actual connection details are pulled from the instance map.
+        worker_name = uri.netloc.decode("utf-8")
+        scheme = self.instance_map[worker_name].scheme()
+
+        if scheme in ("http", "https"):
+            endpoint = HostnameEndpoint(
+                self.reactor,
+                self.instance_map[worker_name].host,
+                self.instance_map[worker_name].port,
+            )
+            if scheme == "https":
                 endpoint = wrapClientTLS(
-                    self.context_factory.creatorForNetloc(uri.host, uri.port), endpoint
+                    # The 'port' argument below isn't actually used by the function
+                    self.context_factory.creatorForNetloc(
+                        self.instance_map[worker_name].host,
+                        self.instance_map[worker_name].port,
+                    ),
+                    endpoint,
                 )
             return endpoint
         else:
-            raise SchemeNotSupported(f"Unsupported scheme: {uri.scheme!r}")
+            raise SchemeNotSupported(f"Unsupported scheme: {scheme}")
 
 
 @implementer(IAgent)
@@ -80,6 +97,7 @@ class ReplicationAgent(_AgentBase):
     def __init__(
         self,
         reactor: ISynapseReactor,
+        instance_map: Dict[str, InstanceLocationConfig],
         contextFactory: IPolicyForHTTPS,
         connectTimeout: Optional[float] = None,
         bindAddress: Optional[bytes] = None,
@@ -102,7 +120,9 @@ class ReplicationAgent(_AgentBase):
                 created.
         """
         _AgentBase.__init__(self, reactor, pool)
-        endpoint_factory = ReplicationEndpointFactory(reactor, contextFactory)
+        endpoint_factory = ReplicationEndpointFactory(
+            reactor, instance_map, contextFactory
+        )
         self._endpointFactory = endpoint_factory
 
     def request(
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index dc7820f963..63cf24a14d 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -219,11 +219,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             with outgoing_gauge.track_inprogress():
                 if instance_name == local_instance_name:
                     raise Exception("Trying to send HTTP request to self")
-                if instance_name in instance_map:
-                    host = instance_map[instance_name].host
-                    port = instance_map[instance_name].port
-                    tls = instance_map[instance_name].tls
-                else:
+                if instance_name not in instance_map:
                     raise Exception(
                         "Instance %r not in 'instance_map' config" % (instance_name,)
                     )
@@ -271,13 +267,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                         "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
                     )
 
-                # Here the protocol is hard coded to be http by default or https in case the replication
-                # port is set to have tls true.
-                scheme = "https" if tls else "http"
-                uri = "%s://%s:%s/_synapse/replication/%s/%s" % (
-                    scheme,
-                    host,
-                    port,
+                # Hard code a special scheme to show this only used for replication. The
+                # instance_name will be passed into the ReplicationEndpointFactory to
+                # determine connection details from the instance_map.
+                uri = "synapse-replication://%s/_synapse/replication/%s/%s" % (
+                    instance_name,
                     cls.NAME,
                     "/".join(url_args),
                 )