summary refs log tree commit diff
path: root/synapse/replication/http/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/http/_base.py')
-rw-r--r--synapse/replication/http/_base.py97
1 files changed, 88 insertions, 9 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 3f4d3fc51a..709327b97f 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -17,7 +17,7 @@ import logging
 import re
 import urllib.parse
 from inspect import signature
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple
 
 from prometheus_client import Counter, Gauge
 
@@ -27,6 +27,7 @@ from twisted.web.server import Request
 from synapse.api.errors import HttpResponseException, SynapseError
 from synapse.http import RequestTimedOutError
 from synapse.http.server import HttpServer
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing
 from synapse.logging.opentracing import trace_with_opname
@@ -53,6 +54,9 @@ _outgoing_request_counter = Counter(
 )
 
 
+_STREAM_POSITION_KEY = "_INT_STREAM_POS"
+
+
 class ReplicationEndpoint(metaclass=abc.ABCMeta):
     """Helper base class for defining new replication HTTP endpoints.
 
@@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             a connection error is received.
         RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
             receiving connection errors, each will backoff exponentially longer.
+        WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
+            catch up before processing the request and/or response. Defaults to
+            True.
     """
 
     NAME: str = abc.abstractproperty()  # type: ignore
@@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
     RETRY_ON_CONNECT_ERROR = True
     RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5  # =63s (2^6-1)
 
+    WAIT_FOR_STREAMS: ClassVar[bool] = True
+
     def __init__(self, hs: "HomeServer"):
         if self.CACHE:
             self.response_cache: ResponseCache[str] = ResponseCache(
@@ -126,6 +135,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if hs.config.worker.worker_replication_secret:
             self._replication_secret = hs.config.worker.worker_replication_secret
 
+        self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
+        self._replication = hs.get_replication_data_handler()
+        self._instance_name = hs.get_instance_name()
+
     def _check_auth(self, request: Request) -> None:
         # Get the authorization header.
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@@ -160,7 +173,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     async def _handle_request(
-        self, request: Request, **kwargs: Any
+        self, request: Request, content: JsonDict, **kwargs: Any
     ) -> Tuple[int, JsonDict]:
         """Handle incoming request.
 
@@ -201,6 +214,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         @trace_with_opname("outgoing_replication_request")
         async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
+            # We have to pull these out here to avoid circular dependencies...
+            streams = hs.get_replication_command_handler().get_streams_to_replicate()
+            replication = hs.get_replication_data_handler()
+
             with outgoing_gauge.track_inprogress():
                 if instance_name == local_instance_name:
                     raise Exception("Trying to send HTTP request to self")
@@ -219,6 +236,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
                 data = await cls._serialize_payload(**kwargs)
 
+                if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
+                    # Include the current stream positions that we write to. We
+                    # don't do this for GETs as they don't have a body, and we
+                    # generally assume that a GET won't rely on data we have
+                    # written.
+                    if _STREAM_POSITION_KEY in data:
+                        raise Exception(
+                            "data to send contains %r key", _STREAM_POSITION_KEY
+                        )
+
+                    data[_STREAM_POSITION_KEY] = {
+                        "streams": {
+                            stream.NAME: stream.current_token(local_instance_name)
+                            for stream in streams
+                        },
+                        "instance_name": local_instance_name,
+                    }
+
                 url_args = [
                     urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
                 ]
@@ -308,6 +343,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                     ) from e
 
                 _outgoing_request_counter.labels(cls.NAME, 200).inc()
+
+                # Wait on any streams that the remote may have written to.
+                for stream_name, position in result.get(
+                    _STREAM_POSITION_KEY, {}
+                ).items():
+                    await replication.wait_for_stream_position(
+                        instance_name=instance_name,
+                        stream_name=stream_name,
+                        position=position,
+                        raise_on_timeout=False,
+                    )
+
                 return result
 
         return send_request
@@ -353,6 +400,23 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if self._replication_secret:
             self._check_auth(request)
 
+        if self.METHOD == "GET":
+            # GET APIs always have an empty body.
+            content = {}
+        else:
+            content = parse_json_object_from_request(request)
+
+        # Wait on any streams that the remote may have written to.
+        for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
+            "streams"
+        ].items():
+            await self._replication.wait_for_stream_position(
+                instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
+                stream_name=stream_name,
+                position=position,
+                raise_on_timeout=False,
+            )
+
         if self.CACHE:
             txn_id = kwargs.pop("txn_id")
 
@@ -361,13 +425,28 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             # correctly yet. In particular, there may be issues to do with logging
             # context lifetimes.
 
-            return await self.response_cache.wrap(
-                txn_id, self._handle_request, request, **kwargs
+            code, response = await self.response_cache.wrap(
+                txn_id, self._handle_request, request, content, **kwargs
             )
+        else:
+            # The `@cancellable` decorator may be applied to `_handle_request`. But we
+            # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
+            # so we have to set up the cancellable flag ourselves.
+            request.is_render_cancellable = is_function_cancellable(
+                self._handle_request
+            )
+
+            code, response = await self._handle_request(request, content, **kwargs)
+
+        # Return streams we may have written to in the course of processing this
+        # request.
+        if _STREAM_POSITION_KEY in response:
+            raise Exception("data to send contains %r key", _STREAM_POSITION_KEY)
 
-        # The `@cancellable` decorator may be applied to `_handle_request`. But we
-        # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
-        # so we have to set up the cancellable flag ourselves.
-        request.is_render_cancellable = is_function_cancellable(self._handle_request)
+        if self.WAIT_FOR_STREAMS:
+            response[_STREAM_POSITION_KEY] = {
+                stream.NAME: stream.current_token(self._instance_name)
+                for stream in self._streams
+            }
 
-        return await self._handle_request(request, **kwargs)
+        return code, response