diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 03560c1f0e..793cef6c26 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,8 @@
import abc
import logging
import re
+from inspect import signature
+from typing import Dict, List, Tuple
from six import raise_from
from six.moves import urllib
@@ -43,7 +45,7 @@ class ReplicationEndpoint(object):
"""Helper base class for defining new replication HTTP endpoints.
This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
- (with an `/:txn_id` prefix for cached requests.), where NAME is a name,
+ (with a `/:txn_id` suffix for cached requests), where NAME is a name,
PATH_ARGS are a tuple of parameters to be encoded in the URL.
For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
@@ -59,6 +61,8 @@ class ReplicationEndpoint(object):
must call `register` to register the path with the HTTP server.
Requests can be sent by calling the client returned by `make_client`.
+ Requests are sent to master process by default, but can be sent to other
+ named processes by specifying an `instance_name` keyword argument.
Attributes:
NAME (str): A name for the endpoint, added to the path as well as used
@@ -78,9 +82,8 @@ class ReplicationEndpoint(object):
__metaclass__ = abc.ABCMeta
- NAME = abc.abstractproperty()
- PATH_ARGS = abc.abstractproperty()
-
+ NAME = abc.abstractproperty() # type: str # type: ignore
+ PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST"
CACHE = True
RETRY_ON_TIMEOUT = True
@@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
)
+ # We reserve `instance_name` as a parameter to sending requests, so we
+ # assert here that sub classes don't try and use the name.
+ assert (
+ "instance_name" not in self.PATH_ARGS
+ ), "`instance_name` is a reserved paramater name"
+ assert (
+ "instance_name"
+ not in signature(self.__class__._serialize_payload).parameters
+ ), "`instance_name` is a reserved paramater name"
+
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod
@@ -110,14 +123,14 @@ class ReplicationEndpoint(object):
return {}
@abc.abstractmethod
- def _handle_request(self, request, **kwargs):
+ async def _handle_request(self, request, **kwargs):
"""Handle incoming request.
This is called with the request object and PATH_ARGS.
Returns:
- Deferred[dict]: A JSON serialisable dict to be used as response
- body of request.
+ tuple[int, dict]: HTTP status code and a JSON serialisable dict
+ to be used as response body of request.
"""
pass
@@ -128,14 +141,30 @@ class ReplicationEndpoint(object):
Returns a callable that accepts the same parameters as `_serialize_payload`.
"""
clock = hs.get_clock()
- host = hs.config.worker_replication_host
- port = hs.config.worker_replication_http_port
-
client = hs.get_simple_http_client()
+ local_instance_name = hs.get_instance_name()
+
+ master_host = hs.config.worker_replication_host
+ master_port = hs.config.worker_replication_http_port
+
+ instance_map = hs.config.worker.instance_map
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
- def send_request(**kwargs):
+ def send_request(instance_name="master", **kwargs):
+ if instance_name == local_instance_name:
+ raise Exception("Trying to send HTTP request to self")
+ if instance_name == "master":
+ host = master_host
+ port = master_port
+ elif instance_name in instance_map:
+ host = instance_map[instance_name].host
+ port = instance_map[instance_name].port
+ else:
+ raise Exception(
+ "Instance %r not in 'instance_map' config" % (instance_name,)
+ )
+
data = yield cls._serialize_payload(**kwargs)
url_args = [
@@ -171,7 +200,7 @@ class ReplicationEndpoint(object):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
- headers = {}
+ headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = yield request_func(uri, data, headers=headers)
@@ -180,7 +209,7 @@ class ReplicationEndpoint(object):
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
raise
- logger.warn("%s request timed out", cls.NAME)
+ logger.warning("%s request timed out", cls.NAME)
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
@@ -207,7 +236,7 @@ class ReplicationEndpoint(object):
method = self.METHOD
if self.CACHE:
- handler = self._cached_handler
+ handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
|