diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 2f16955954..c287c4e269 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
-from synapse.events import event_type_from_format_version
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -28,7 +29,7 @@ logger = logging.getLogger(__name__)
class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""Handles events newly received from federation, including persisting and
- notifying.
+ notifying. Returns the maximum stream ID of the persisted events.
The API looks like:
@@ -37,11 +38,21 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
{
"events": [{
"event": { .. serialized event .. },
+ "room_version": .., // "1", "2", "3", etc: the version of the room
+ // containing the event
+ "event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
}],
"backfilled": false
+ }
+
+ 200 OK
+
+ {
+ "max_stream_id": 32443,
+ }
"""
NAME = "fed_send_events"
@@ -51,6 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.clock = hs.get_clock()
self.federation_handler = hs.get_handlers().federation_handler
@@ -71,6 +83,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_payloads.append(
{
"event": event.get_pdu_json(),
+ "room_version": event.room_version.identifier,
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
@@ -82,8 +95,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload
- @defer.inlineCallbacks
- def _handle_request(self, request):
+ async def _handle_request(self, request):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
@@ -94,26 +106,27 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts = []
for event_payload in event_payloads:
event_dict = event_payload["event"]
- format_ver = event_payload["event_format_version"]
+ room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]]
internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"]
- EventType = event_type_from_format_version(format_ver)
- event = EventType(event_dict, internal_metadata, rejected_reason)
+ event = make_event_from_dict(
+ event_dict, room_ver, internal_metadata, rejected_reason
+ )
- context = yield EventContext.deserialize(
- self.store, event_payload["context"]
+ context = EventContext.deserialize(
+ self.storage, event_payload["context"]
)
event_and_contexts.append((event, context))
logger.info("Got %d events from federation", len(event_and_contexts))
- yield self.federation_handler.persist_events_and_notify(
+ max_stream_id = await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
)
- return 200, {}
+ return 200, {"max_stream_id": max_stream_id}
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
@@ -144,8 +157,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
def _serialize_payload(edu_type, origin, content):
return {"origin": origin, "content": content}
- @defer.inlineCallbacks
- def _handle_request(self, request, edu_type):
+ async def _handle_request(self, request, edu_type):
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)
@@ -154,7 +166,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
logger.info("Got %r edu from %s", edu_type, origin)
- result = yield self.registry.on_edu(edu_type, origin, edu_content)
+ result = await self.registry.on_edu(edu_type, origin, edu_content)
return 200, result
@@ -193,8 +205,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
"""
return {"args": args}
- @defer.inlineCallbacks
- def _handle_request(self, request, query_type):
+ async def _handle_request(self, request, query_type):
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)
@@ -202,7 +213,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
logger.info("Got %r query", query_type)
- result = yield self.registry.on_query(query_type, args)
+ result = await self.registry.on_query(query_type, args)
return 200, result
@@ -213,7 +224,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
Request format:
- POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
+ POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id
{}
"""
@@ -234,10 +245,41 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
"""
return {}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id):
- yield self.store.clean_room_for_join(room_id)
+ async def _handle_request(self, request, room_id):
+ await self.store.clean_room_for_join(room_id)
+
+ return 200, {}
+
+
+class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
+ """Called to clean up any data in DB for a given room, ready for the
+ server to join the room.
+
+ Request format:
+
+ POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id
+
+ {
+ "room_version": "1",
+ }
+ """
+
+ NAME = "store_room_on_invite"
+ PATH_ARGS = ("room_id",)
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.store = hs.get_datastore()
+
+ @staticmethod
+ def _serialize_payload(room_id, room_version):
+ return {"room_version": room_version.identifier}
+ async def _handle_request(self, request, room_id):
+ content = parse_json_object_from_request(request)
+ room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
+ await self.store.maybe_store_room_on_invite(room_id, room_version)
return 200, {}
@@ -246,3 +288,4 @@ def register_servlets(hs, http_server):
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
+ ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
|