diff options
Diffstat (limited to 'synapse/replication/http/_base.py')
-rw-r--r-- | synapse/replication/http/_base.py | 57 |
1 files changed, 43 insertions, 14 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 03560c1f0e..793cef6c26 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,6 +16,8 @@ import abc import logging import re +from inspect import signature +from typing import Dict, List, Tuple from six import raise_from from six.moves import urllib @@ -43,7 +45,7 @@ class ReplicationEndpoint(object): """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` - (with an `/:txn_id` prefix for cached requests.), where NAME is a name, + (with a `/:txn_id` suffix for cached requests), where NAME is a name, PATH_ARGS are a tuple of parameters to be encoded in the URL. For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`, @@ -59,6 +61,8 @@ class ReplicationEndpoint(object): must call `register` to register the path with the HTTP server. Requests can be sent by calling the client returned by `make_client`. + Requests are sent to master process by default, but can be sent to other + named processes by specifying an `instance_name` keyword argument. Attributes: NAME (str): A name for the endpoint, added to the path as well as used @@ -78,9 +82,8 @@ class ReplicationEndpoint(object): __metaclass__ = abc.ABCMeta - NAME = abc.abstractproperty() - PATH_ARGS = abc.abstractproperty() - + NAME = abc.abstractproperty() # type: str # type: ignore + PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore METHOD = "POST" CACHE = True RETRY_ON_TIMEOUT = True @@ -91,6 +94,16 @@ class ReplicationEndpoint(object): hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) + # We reserve `instance_name` as a parameter to sending requests, so we + # assert here that sub classes don't try and use the name. + assert ( + "instance_name" not in self.PATH_ARGS + ), "`instance_name` is a reserved paramater name" + assert ( + "instance_name" + not in signature(self.__class__._serialize_payload).parameters + ), "`instance_name` is a reserved paramater name" + assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod @@ -110,14 +123,14 @@ class ReplicationEndpoint(object): return {} @abc.abstractmethod - def _handle_request(self, request, **kwargs): + async def _handle_request(self, request, **kwargs): """Handle incoming request. This is called with the request object and PATH_ARGS. Returns: - Deferred[dict]: A JSON serialisable dict to be used as response - body of request. + tuple[int, dict]: HTTP status code and a JSON serialisable dict + to be used as response body of request. """ pass @@ -128,14 +141,30 @@ class ReplicationEndpoint(object): Returns a callable that accepts the same parameters as `_serialize_payload`. """ clock = hs.get_clock() - host = hs.config.worker_replication_host - port = hs.config.worker_replication_http_port - client = hs.get_simple_http_client() + local_instance_name = hs.get_instance_name() + + master_host = hs.config.worker_replication_host + master_port = hs.config.worker_replication_http_port + + instance_map = hs.config.worker.instance_map @trace(opname="outgoing_replication_request") @defer.inlineCallbacks - def send_request(**kwargs): + def send_request(instance_name="master", **kwargs): + if instance_name == local_instance_name: + raise Exception("Trying to send HTTP request to self") + if instance_name == "master": + host = master_host + port = master_port + elif instance_name in instance_map: + host = instance_map[instance_name].host + port = instance_map[instance_name].port + else: + raise Exception( + "Instance %r not in 'instance_map' config" % (instance_name,) + ) + data = yield cls._serialize_payload(**kwargs) url_args = [ @@ -171,7 +200,7 @@ class ReplicationEndpoint(object): # have a good idea that the request has either succeeded or failed on # the master, and so whether we should clean up or not. while True: - headers = {} + headers = {} # type: Dict[bytes, List[bytes]] inject_active_span_byte_dict(headers, None, check_destination=False) try: result = yield request_func(uri, data, headers=headers) @@ -180,7 +209,7 @@ class ReplicationEndpoint(object): if e.code != 504 or not cls.RETRY_ON_TIMEOUT: raise - logger.warn("%s request timed out", cls.NAME) + logger.warning("%s request timed out", cls.NAME) # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. @@ -207,7 +236,7 @@ class ReplicationEndpoint(object): method = self.METHOD if self.CACHE: - handler = self._cached_handler + handler = self._cached_handler # type: ignore url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) |