diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 6df000faaf..904a721483 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -2259,6 +2259,10 @@ class FederationEventHandler:
event_and_contexts, backfilled=backfilled
)
+ # After persistence we always need to notify replication there may
+ # be new data.
+ self._notifier.notify_replication()
+
if self._ephemeral_messages_enabled:
for event in events:
# If there's an expiry timestamp on the event, schedule its expiry.
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 3f4d3fc51a..709327b97f 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -17,7 +17,7 @@ import logging
import re
import urllib.parse
from inspect import signature
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple
from prometheus_client import Counter, Gauge
@@ -27,6 +27,7 @@ from twisted.web.server import Request
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
+from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace_with_opname
@@ -53,6 +54,9 @@ _outgoing_request_counter = Counter(
)
+_STREAM_POSITION_KEY = "_INT_STREAM_POS"
+
+
class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""Helper base class for defining new replication HTTP endpoints.
@@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
a connection error is received.
RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
receiving connection errors, each will backoff exponentially longer.
+ WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
+ catch up before processing the request and/or response. Defaults to
+ True.
"""
NAME: str = abc.abstractproperty() # type: ignore
@@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
RETRY_ON_CONNECT_ERROR = True
RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1)
+ WAIT_FOR_STREAMS: ClassVar[bool] = True
+
def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache: ResponseCache[str] = ResponseCache(
@@ -126,6 +135,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret
+ self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
+ self._replication = hs.get_replication_data_handler()
+ self._instance_name = hs.get_instance_name()
+
def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@@ -160,7 +173,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def _handle_request(
- self, request: Request, **kwargs: Any
+ self, request: Request, content: JsonDict, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Handle incoming request.
@@ -201,6 +214,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
+ # We have to pull these out here to avoid circular dependencies...
+ streams = hs.get_replication_command_handler().get_streams_to_replicate()
+ replication = hs.get_replication_data_handler()
+
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
@@ -219,6 +236,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
data = await cls._serialize_payload(**kwargs)
+ if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
+ # Include the current stream positions that we write to. We
+ # don't do this for GETs as they don't have a body, and we
+ # generally assume that a GET won't rely on data we have
+ # written.
+                if _STREAM_POSITION_KEY in data:
+                    raise Exception(
+                        "data to send contains %r key" % (_STREAM_POSITION_KEY,)
+                    )
+
+ data[_STREAM_POSITION_KEY] = {
+ "streams": {
+ stream.NAME: stream.current_token(local_instance_name)
+ for stream in streams
+ },
+ "instance_name": local_instance_name,
+ }
+
url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
@@ -308,6 +343,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
) from e
_outgoing_request_counter.labels(cls.NAME, 200).inc()
+
+ # Wait on any streams that the remote may have written to.
+ for stream_name, position in result.get(
+ _STREAM_POSITION_KEY, {}
+ ).items():
+ await replication.wait_for_stream_position(
+ instance_name=instance_name,
+ stream_name=stream_name,
+ position=position,
+ raise_on_timeout=False,
+ )
+
return result
return send_request
@@ -353,6 +400,23 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self._replication_secret:
self._check_auth(request)
+ if self.METHOD == "GET":
+ # GET APIs always have an empty body.
+ content = {}
+ else:
+ content = parse_json_object_from_request(request)
+
+ # Wait on any streams that the remote may have written to.
+ for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
+ "streams"
+ ].items():
+ await self._replication.wait_for_stream_position(
+ instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
+ stream_name=stream_name,
+ position=position,
+ raise_on_timeout=False,
+ )
+
if self.CACHE:
txn_id = kwargs.pop("txn_id")
@@ -361,13 +425,28 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# correctly yet. In particular, there may be issues to do with logging
# context lifetimes.
- return await self.response_cache.wrap(
- txn_id, self._handle_request, request, **kwargs
+ code, response = await self.response_cache.wrap(
+ txn_id, self._handle_request, request, content, **kwargs
)
+ else:
+ # The `@cancellable` decorator may be applied to `_handle_request`. But we
+ # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
+ # so we have to set up the cancellable flag ourselves.
+ request.is_render_cancellable = is_function_cancellable(
+ self._handle_request
+ )
+
+ code, response = await self._handle_request(request, content, **kwargs)
+
+ # Return streams we may have written to in the course of processing this
+ # request.
+            if _STREAM_POSITION_KEY in response:
+                raise Exception("data to send contains %r key" % (_STREAM_POSITION_KEY,))
- # The `@cancellable` decorator may be applied to `_handle_request`. But we
- # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
- # so we have to set up the cancellable flag ourselves.
- request.is_render_cancellable = is_function_cancellable(self._handle_request)
+ if self.WAIT_FOR_STREAMS:
+ response[_STREAM_POSITION_KEY] = {
+ stream.NAME: stream.current_token(self._instance_name)
+ for stream in self._streams
+ }
- return await self._handle_request(request, **kwargs)
+ return code, response
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
index 0edc95977b..2374f810c9 100644
--- a/synapse/replication/http/account_data.py
+++ b/synapse/replication/http/account_data.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -61,10 +60,8 @@ class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, account_data_type: str
+ self, request: Request, content: JsonDict, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
max_stream_id = await self.handler.add_account_data_for_user(
user_id, account_data_type, content["content"]
)
@@ -101,7 +98,7 @@ class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, account_data_type: str
+ self, request: Request, content: JsonDict, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_user(
user_id, account_data_type
@@ -143,10 +140,13 @@ class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, room_id: str, account_data_type: str
+ self,
+ request: Request,
+ content: JsonDict,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
max_stream_id = await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, content["content"]
)
@@ -183,7 +183,12 @@ class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, room_id: str, account_data_type: str
+ self,
+ request: Request,
+ content: JsonDict,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_room(
user_id, room_id, account_data_type
@@ -225,10 +230,8 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, room_id: str, tag: str
+ self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
max_stream_id = await self.handler.add_tag_to_room(
user_id, room_id, tag, content["content"]
)
@@ -266,7 +269,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str, room_id: str, tag: str
+ self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_tag_from_room(
user_id,
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index ea5c08e6cf..ecea6fc915 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.logging.opentracing import active_span
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -78,7 +77,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id)
@@ -138,9 +137,8 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
return {"user_ids": user_ids}
async def _handle_request( # type: ignore[override]
- self, request: Request
+ self, request: Request, content: JsonDict
) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
- content = parse_json_object_from_request(request)
user_ids: List[str] = content["user_ids"]
logger.info("Resync for %r", user_ids)
@@ -205,10 +203,8 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: Request
+ self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
user_id = content["user_id"]
device_id = content["device_id"]
keys = content["keys"]
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index d3abafed28..53ad327030 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
from synapse.util.metrics import Measure
@@ -114,10 +113,8 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload
- async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
+ async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override]
with Measure(self.clock, "repl_fed_send_events_parse"):
- content = parse_json_object_from_request(request)
-
room_id = content["room_id"]
backfilled = content["backfilled"]
@@ -181,13 +178,10 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
return {"origin": origin, "content": content}
async def _handle_request( # type: ignore[override]
- self, request: Request, edu_type: str
+ self, request: Request, content: JsonDict, edu_type: str
) -> Tuple[int, JsonDict]:
- with Measure(self.clock, "repl_fed_send_edu_parse"):
- content = parse_json_object_from_request(request)
-
- origin = content["origin"]
- edu_content = content["content"]
+ origin = content["origin"]
+ edu_content = content["content"]
logger.info("Got %r edu from %s", edu_type, origin)
@@ -231,13 +225,10 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
return {"args": args}
async def _handle_request( # type: ignore[override]
- self, request: Request, query_type: str
+ self, request: Request, content: JsonDict, query_type: str
) -> Tuple[int, JsonDict]:
- with Measure(self.clock, "repl_fed_query_parse"):
- content = parse_json_object_from_request(request)
-
- args = content["args"]
- args["origin"] = content["origin"]
+ args = content["args"]
+ args["origin"] = content["origin"]
logger.info("Got %r query from %s", query_type, args["origin"])
@@ -274,7 +265,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, room_id: str
+ self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
await self.store.clean_room_for_join(room_id)
@@ -307,9 +298,8 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
return {"room_version": room_version.identifier}
async def _handle_request( # type: ignore[override]
- self, request: Request, room_id: str
+ self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index c68e18da12..6ad6cb1bfe 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple, cast
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -73,10 +72,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 663bff5738..9fa1060d48 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
@@ -79,10 +78,8 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: SynapseRequest, room_id: str, user_id: str
+ self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
remote_room_hosts = content["remote_room_hosts"]
event_content = content["content"]
@@ -147,11 +144,10 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
async def _handle_request( # type: ignore[override]
self,
request: SynapseRequest,
+ content: JsonDict,
room_id: str,
user_id: str,
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
remote_room_hosts = content["remote_room_hosts"]
event_content = content["content"]
@@ -217,10 +213,8 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: SynapseRequest, invite_event_id: str
+ self, request: SynapseRequest, content: JsonDict, invite_event_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
txn_id = content["txn_id"]
event_content = content["content"]
@@ -285,10 +279,9 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
async def _handle_request( # type: ignore[override]
self,
request: SynapseRequest,
+ content: JsonDict,
knock_event_id: str,
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
txn_id = content["txn_id"]
event_content = content["content"]
@@ -347,7 +340,12 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, room_id: str, user_id: str, change: str
+ self,
+ request: Request,
+ content: JsonDict,
+ room_id: str,
+ user_id: str,
+ change: str,
) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index 4a5b08f56f..db16aac9c2 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, UserID
@@ -56,7 +55,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
await self._presence_handler.bump_presence_active_time(
UserID.from_string(user_id)
@@ -107,10 +106,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
await self._presence_handler.set_state(
UserID.from_string(user_id),
content["state"],
diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
index af5c2f66a7..297e8ad564 100644
--- a/synapse/replication/http/push.py
+++ b/synapse/replication/http/push.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -61,10 +60,8 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
app_id = content["app_id"]
pushkey = content["pushkey"]
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 976c283360..265e601b96 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -96,10 +95,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
await self.registration_handler.check_registration_ratelimit(content["address"])
# Always default admin users to approved (since it means they were created by
@@ -150,10 +147,8 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
return {"auth_result": auth_result, "access_token": access_token}
async def _handle_request( # type: ignore[override]
- self, request: Request, user_id: str
+ self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
-
auth_result = content["auth_result"]
access_token = content["access_token"]
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 4215a1c1bc..27ad914075 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
from synapse.util.metrics import Measure
@@ -114,11 +113,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request, event_id: str
+ self, request: Request, content: JsonDict, event_id: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_send_event_parse"):
- content = parse_json_object_from_request(request)
-
event_dict = content["event"]
room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
internal_metadata = content["internal_metadata"]
diff --git a/synapse/replication/http/send_events.py b/synapse/replication/http/send_events.py
index 8889bbb644..4f82c9f96d 100644
--- a/synapse/replication/http/send_events.py
+++ b/synapse/replication/http/send_events.py
@@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
from synapse.util.metrics import Measure
@@ -114,10 +113,9 @@ class ReplicationSendEventsRestServlet(ReplicationEndpoint):
return payload
async def _handle_request( # type: ignore[override]
- self, request: Request
+ self, request: Request, payload: JsonDict
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_send_events_parse"):
- payload = parse_json_object_from_request(request)
events_and_context = []
events = payload["events"]
diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py
index 838b7584e5..0c524e7de3 100644
--- a/synapse/replication/http/state.py
+++ b/synapse/replication/http/state.py
@@ -57,7 +57,7 @@ class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore[override]
- self, request: Request, room_id: str
+ self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
writer_instance = self._events_shard_config.get_instance(room_id)
if writer_instance != self._instance_name:
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index c065225362..3c7b5b18ea 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -54,6 +54,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
PATH_ARGS = ("stream_name",)
METHOD = "GET"
+ # We don't want to wait for replication streams to catch up, as this gets
+ # called in the process of catching replication streams up.
+ WAIT_FOR_STREAMS = False
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -67,7 +71,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request( # type: ignore[override]
- self, request: Request, stream_name: str
+ self, request: Request, content: JsonDict, stream_name: str
) -> Tuple[int, JsonDict]:
stream = self.streams.get(stream_name)
if stream is None:
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 322d695bc7..5c2482e40c 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,6 +16,7 @@
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
@@ -314,10 +315,21 @@ class ReplicationDataHandler:
self.send_handler.wake_destination(server)
async def wait_for_stream_position(
- self, instance_name: str, stream_name: str, position: int
+ self,
+ instance_name: str,
+ stream_name: str,
+ position: int,
+ raise_on_timeout: bool = True,
) -> None:
"""Wait until this instance has received updates up to and including
the given stream position.
+
+ Args:
+ instance_name
+ stream_name
+ position
+ raise_on_timeout: Whether to raise an exception if we time out
+ waiting for the updates, or if we log an error and return.
"""
if instance_name == self._instance_name:
@@ -345,7 +357,16 @@ class ReplicationDataHandler:
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
- await make_deferred_yieldable(deferred)
+ try:
+ await make_deferred_yieldable(deferred)
+ except defer.TimeoutError:
+ logger.error("Timed out waiting for stream %s", stream_name)
+
+ if raise_on_timeout:
+ raise
+
+ return
+
logger.info(
"Finished waiting for repl stream %r to reach %s", stream_name, position
)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 99f09669f0..9d17eff714 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -199,33 +199,28 @@ class ReplicationStreamer:
# The token has advanced but there is no data to
# send, so we send a `POSITION` to inform other
# workers of the updated position.
- if stream.NAME == EventsStream.NAME:
- # XXX: We only do this for the EventStream as it
- # turns out that e.g. account data streams share
- # their "current token" with each other, meaning
- # that it is *not* safe to send a POSITION.
-
- # Note: `last_token` may not *actually* be the
- # last token we sent out in a RDATA or POSITION.
- # This can happen if we sent out an RDATA for
- # position X when our current token was say X+1.
- # Other workers will see RDATA for X and then a
- # POSITION with last token of X+1, which will
- # cause them to check if there were any missing
- # updates between X and X+1.
- logger.info(
- "Sending position: %s -> %s",
+
+ # Note: `last_token` may not *actually* be the
+ # last token we sent out in a RDATA or POSITION.
+ # This can happen if we sent out an RDATA for
+ # position X when our current token was say X+1.
+ # Other workers will see RDATA for X and then a
+ # POSITION with last token of X+1, which will
+ # cause them to check if there were any missing
+ # updates between X and X+1.
+ logger.info(
+ "Sending position: %s -> %s",
+ stream.NAME,
+ current_token,
+ )
+ self.command_handler.send_command(
+ PositionCommand(
stream.NAME,
+ self._instance_name,
+ last_token,
current_token,
)
- self.command_handler.send_command(
- PositionCommand(
- stream.NAME,
- self._instance_name,
- last_token,
- current_token,
- )
- )
+ )
continue
# Some streams return multiple rows with the same stream IDs,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 0d7108f01b..8670ffbfa3 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -378,6 +378,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._current_positions.values(), default=1
)
+ if not writers:
+ # If there have been no explicit writers given then any instance can
+ # write to the stream. In which case, let's pre-seed our own
+ # position with the current minimum.
+ self._current_positions[self._instance_name] = self._persisted_upto_position
+
def _load_current_ids(
self,
db_conn: LoggingDatabaseConnection,
@@ -695,24 +701,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
heapq.heappush(self._known_persisted_positions, new_id)
- # If we're a writer and we don't have any active writes we update our
- # current position to the latest position seen. This allows the instance
- # to report a recent position when asked, rather than a potentially old
- # one (if this instance hasn't written anything for a while).
- our_current_position = self._current_positions.get(self._instance_name)
- if (
- our_current_position
- and not self._unfinished_ids
- and not self._in_flight_fetches
- ):
- self._current_positions[self._instance_name] = max(
- our_current_position, new_id
- )
-
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
- min_curr = min(self._current_positions.values(), default=0)
+ our_current_position = self._current_positions.get(self._instance_name, 0)
+ min_curr = min(
+ (
+ token
+ for name, token in self._current_positions.items()
+ if name != self._instance_name
+ ),
+ default=our_current_position,
+ )
+
+ if our_current_position and (self._unfinished_ids or self._in_flight_fetches):
+ min_curr = min(min_curr, our_current_position)
+
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 0c725eb967..c59eca2430 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -604,6 +604,12 @@ class RoomStreamToken:
elif self.instance_map:
entries = []
for name, pos in self.instance_map.items():
+ if pos <= self.stream:
+ # Ignore instances who are below the minimum stream position
+ # (we might know they've advanced without seeing a recent
+ # write from them).
+ continue
+
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
|