diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 3f4d3fc51a..709327b97f 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -17,7 +17,7 @@ import logging
import re
import urllib.parse
from inspect import signature
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple
from prometheus_client import Counter, Gauge
@@ -27,6 +27,7 @@ from twisted.web.server import Request
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
+from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace_with_opname
@@ -53,6 +54,9 @@ _outgoing_request_counter = Counter(
)
+_STREAM_POSITION_KEY = "_INT_STREAM_POS"
+
+
class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""Helper base class for defining new replication HTTP endpoints.
@@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
a connection error is received.
RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
receiving connection errors, each will backoff exponentially longer.
+ WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
+ catch up before processing the request and/or response. Defaults to
+ True.
"""
NAME: str = abc.abstractproperty() # type: ignore
@@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
RETRY_ON_CONNECT_ERROR = True
RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1)
+ WAIT_FOR_STREAMS: ClassVar[bool] = True
+
def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache: ResponseCache[str] = ResponseCache(
@@ -126,6 +135,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret
+ self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
+ self._replication = hs.get_replication_data_handler()
+ self._instance_name = hs.get_instance_name()
+
def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@@ -160,7 +173,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def _handle_request(
- self, request: Request, **kwargs: Any
+ self, request: Request, content: JsonDict, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Handle incoming request.
@@ -201,6 +214,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
+ # We have to pull these out here to avoid circular dependencies...
+ streams = hs.get_replication_command_handler().get_streams_to_replicate()
+ replication = hs.get_replication_data_handler()
+
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
@@ -219,6 +236,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
data = await cls._serialize_payload(**kwargs)
+if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
+# Include the current stream positions that we write to. We
+# don't do this for GETs as they don't have a body, and we
+# generally assume that a GET won't rely on data we have
+# written.
+if _STREAM_POSITION_KEY in data:
+raise Exception(
+"data to send contains %r key" % (_STREAM_POSITION_KEY,)
+)
+
+data[_STREAM_POSITION_KEY] = {
+"streams": {
+stream.NAME: stream.current_token(local_instance_name)
+for stream in streams
+},
+"instance_name": local_instance_name,
+}
+
url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
@@ -308,6 +343,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
) from e
_outgoing_request_counter.labels(cls.NAME, 200).inc()
+
+# Pop our internal key and wait on any streams the remote wrote to.
+for stream_name, position in result.pop(
+_STREAM_POSITION_KEY, {}
+).items():
+await replication.wait_for_stream_position(
+instance_name=instance_name,
+stream_name=stream_name,
+position=position,
+raise_on_timeout=False,
+)
+
return result
return send_request
@@ -353,6 +400,23 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self._replication_secret:
self._check_auth(request)
+ if self.METHOD == "GET":
+ # GET APIs always have an empty body.
+ content = {}
+ else:
+ content = parse_json_object_from_request(request)
+
+ # Wait on any streams that the remote may have written to.
+ for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
+ "streams"
+ ].items():
+ await self._replication.wait_for_stream_position(
+ instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
+ stream_name=stream_name,
+ position=position,
+ raise_on_timeout=False,
+ )
+
if self.CACHE:
txn_id = kwargs.pop("txn_id")
@@ -361,13 +425,30 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# correctly yet. In particular, there may be issues to do with logging
# context lifetimes.
-return await self.response_cache.wrap(
-txn_id, self._handle_request, request, **kwargs
+code, response = await self.response_cache.wrap(
+txn_id, self._handle_request, request, content, **kwargs
)
+else:
+# The `@cancellable` decorator may be applied to `_handle_request`. But we
+# told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
+# so we have to set up the cancellable flag ourselves.
+request.is_render_cancellable = is_function_cancellable(
+self._handle_request
+)
+
+code, response = await self._handle_request(request, content, **kwargs)
+
+# Return streams we may have written to in the course of processing this
+# request.
+if _STREAM_POSITION_KEY in response:
+raise Exception("data to send contains %r key" % (_STREAM_POSITION_KEY,))
-# The `@cancellable` decorator may be applied to `_handle_request`. But we
-# told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
-# so we have to set up the cancellable flag ourselves.
-request.is_render_cancellable = is_function_cancellable(self._handle_request)
+if self.WAIT_FOR_STREAMS:
+# Copy first so we never mutate a dict that the response cache may share.
+response = dict(response)
+response[_STREAM_POSITION_KEY] = {
+stream.NAME: stream.current_token(self._instance_name)
+for stream in self._streams
+}
-return await self._handle_request(request, **kwargs)
+return code, response
|