summary refs log tree commit diff
path: root/synapse/replication/http/_base.py
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2020-06-10 11:42:30 +0100
committerBrendan Abolivier <babolivier@matrix.org>2020-06-10 11:42:30 +0100
commitec0a7b9034806d6b2ba086bae58f5c6b0fd14672 (patch)
treef2af547b1342795e10548f8fb7a9cfc93e03df37 /synapse/replication/http/_base.py
parentchangelog (diff)
parent1.15.0rc1 (diff)
downloadsynapse-ec0a7b9034806d6b2ba086bae58f5c6b0fd14672.tar.xz
Merge branch 'develop' into babolivier/mark_unread
Diffstat (limited to 'synapse/replication/http/_base.py')
-rw-r--r--synapse/replication/http/_base.py57
1 files changed, 43 insertions, 14 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 03560c1f0e..793cef6c26 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,8 @@
 import abc
 import logging
 import re
+from inspect import signature
+from typing import Dict, List, Tuple
 
 from six import raise_from
 from six.moves import urllib
@@ -43,7 +45,7 @@ class ReplicationEndpoint(object):
     """Helper base class for defining new replication HTTP endpoints.
 
     This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
-    (with an `/:txn_id` prefix for cached requests.), where NAME is a name,
+    (with a `/:txn_id` suffix for cached requests), where NAME is a name,
     PATH_ARGS are a tuple of parameters to be encoded in the URL.
 
     For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
@@ -59,6 +61,8 @@ class ReplicationEndpoint(object):
     must call `register` to register the path with the HTTP server.
 
     Requests can be sent by calling the client returned by `make_client`.
+    Requests are sent to master process by default, but can be sent to other
+    named processes by specifying an `instance_name` keyword argument.
 
     Attributes:
         NAME (str): A name for the endpoint, added to the path as well as used
@@ -78,9 +82,8 @@ class ReplicationEndpoint(object):
 
     __metaclass__ = abc.ABCMeta
 
-    NAME = abc.abstractproperty()
-    PATH_ARGS = abc.abstractproperty()
-
+    NAME = abc.abstractproperty()  # type: str  # type: ignore
+    PATH_ARGS = abc.abstractproperty()  # type: Tuple[str, ...]  # type: ignore
     METHOD = "POST"
     CACHE = True
     RETRY_ON_TIMEOUT = True
@@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
                 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
             )
 
+        # We reserve `instance_name` as a parameter to sending requests, so we
+        # assert here that sub classes don't try and use the name.
+        assert (
+            "instance_name" not in self.PATH_ARGS
+        ), "`instance_name` is a reserved paramater name"
+        assert (
+            "instance_name"
+            not in signature(self.__class__._serialize_payload).parameters
+        ), "`instance_name` is a reserved paramater name"
+
         assert self.METHOD in ("PUT", "POST", "GET")
 
     @abc.abstractmethod
@@ -110,14 +123,14 @@ class ReplicationEndpoint(object):
         return {}
 
     @abc.abstractmethod
-    def _handle_request(self, request, **kwargs):
+    async def _handle_request(self, request, **kwargs):
         """Handle incoming request.
 
         This is called with the request object and PATH_ARGS.
 
         Returns:
-            Deferred[dict]: A JSON serialisable dict to be used as response
-            body of request.
+            tuple[int, dict]: HTTP status code and a JSON serialisable dict
+            to be used as response body of request.
         """
         pass
 
@@ -128,14 +141,30 @@ class ReplicationEndpoint(object):
         Returns a callable that accepts the same parameters as `_serialize_payload`.
         """
         clock = hs.get_clock()
-        host = hs.config.worker_replication_host
-        port = hs.config.worker_replication_http_port
-
         client = hs.get_simple_http_client()
+        local_instance_name = hs.get_instance_name()
+
+        master_host = hs.config.worker_replication_host
+        master_port = hs.config.worker_replication_http_port
+
+        instance_map = hs.config.worker.instance_map
 
         @trace(opname="outgoing_replication_request")
         @defer.inlineCallbacks
-        def send_request(**kwargs):
+        def send_request(instance_name="master", **kwargs):
+            if instance_name == local_instance_name:
+                raise Exception("Trying to send HTTP request to self")
+            if instance_name == "master":
+                host = master_host
+                port = master_port
+            elif instance_name in instance_map:
+                host = instance_map[instance_name].host
+                port = instance_map[instance_name].port
+            else:
+                raise Exception(
+                    "Instance %r not in 'instance_map' config" % (instance_name,)
+                )
+
             data = yield cls._serialize_payload(**kwargs)
 
             url_args = [
@@ -171,7 +200,7 @@ class ReplicationEndpoint(object):
                 # have a good idea that the request has either succeeded or failed on
                 # the master, and so whether we should clean up or not.
                 while True:
-                    headers = {}
+                    headers = {}  # type: Dict[bytes, List[bytes]]
                     inject_active_span_byte_dict(headers, None, check_destination=False)
                     try:
                         result = yield request_func(uri, data, headers=headers)
@@ -180,7 +209,7 @@ class ReplicationEndpoint(object):
                         if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
                             raise
 
-                    logger.warn("%s request timed out", cls.NAME)
+                    logger.warning("%s request timed out", cls.NAME)
 
                     # If we timed out we probably don't need to worry about backing
                     # off too much, but lets just wait a little anyway.
@@ -207,7 +236,7 @@ class ReplicationEndpoint(object):
         method = self.METHOD
 
         if self.CACHE:
-            handler = self._cached_handler
+            handler = self._cached_handler  # type: ignore
             url_args.append("txn_id")
 
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)