summary refs log tree commit diff
path: root/synapse/replication/http/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/http/federation.py')
-rw-r--r--synapse/replication/http/federation.py85
1 files changed, 64 insertions, 21 deletions
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 2f16955954..c287c4e269 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -17,7 +17,8 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.events import event_type_from_format_version
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
@@ -28,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
     """Handles events newly received from federation, including persisting and
-    notifying.
+    notifying. Returns the maximum stream ID of the persisted events.
 
     The API looks like:
 
@@ -37,11 +38,21 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         {
             "events": [{
                 "event": { .. serialized event .. },
+                "room_version": .., // "1", "2", "3", etc: the version of the room
+                                    // containing the event
+                "event_format_version": .., // 1,2,3 etc: the event format version
                 "internal_metadata": { .. serialized internal_metadata .. },
                 "rejected_reason": ..,   // The event.rejected_reason field
                 "context": { .. serialized event context .. },
             }],
             "backfilled": false
+        }
+
+        200 OK
+
+        {
+            "max_stream_id": 32443,
+        }
     """
 
     NAME = "fed_send_events"
@@ -51,6 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
 
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.clock = hs.get_clock()
         self.federation_handler = hs.get_handlers().federation_handler
 
@@ -71,6 +83,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
             event_payloads.append(
                 {
                     "event": event.get_pdu_json(),
+                    "room_version": event.room_version.identifier,
                     "event_format_version": event.format_version,
                     "internal_metadata": event.internal_metadata.get_dict(),
                     "rejected_reason": event.rejected_reason,
@@ -82,8 +95,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
 
         return payload
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request):
+    async def _handle_request(self, request):
         with Measure(self.clock, "repl_fed_send_events_parse"):
             content = parse_json_object_from_request(request)
 
@@ -94,26 +106,27 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
             event_and_contexts = []
             for event_payload in event_payloads:
                 event_dict = event_payload["event"]
-                format_ver = event_payload["event_format_version"]
+                room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]]
                 internal_metadata = event_payload["internal_metadata"]
                 rejected_reason = event_payload["rejected_reason"]
 
-                EventType = event_type_from_format_version(format_ver)
-                event = EventType(event_dict, internal_metadata, rejected_reason)
+                event = make_event_from_dict(
+                    event_dict, room_ver, internal_metadata, rejected_reason
+                )
 
-                context = yield EventContext.deserialize(
-                    self.store, event_payload["context"]
+                context = EventContext.deserialize(
+                    self.storage, event_payload["context"]
                 )
 
                 event_and_contexts.append((event, context))
 
         logger.info("Got %d events from federation", len(event_and_contexts))
 
-        yield self.federation_handler.persist_events_and_notify(
+        max_stream_id = await self.federation_handler.persist_events_and_notify(
             event_and_contexts, backfilled
         )
 
-        return 200, {}
+        return 200, {"max_stream_id": max_stream_id}
 
 
 class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
@@ -144,8 +157,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
     def _serialize_payload(edu_type, origin, content):
         return {"origin": origin, "content": content}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, edu_type):
+    async def _handle_request(self, request, edu_type):
         with Measure(self.clock, "repl_fed_send_edu_parse"):
             content = parse_json_object_from_request(request)
 
@@ -154,7 +166,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
 
         logger.info("Got %r edu from %s", edu_type, origin)
 
-        result = yield self.registry.on_edu(edu_type, origin, edu_content)
+        result = await self.registry.on_edu(edu_type, origin, edu_content)
 
         return 200, result
 
@@ -193,8 +205,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
         """
         return {"args": args}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, query_type):
+    async def _handle_request(self, request, query_type):
         with Measure(self.clock, "repl_fed_query_parse"):
             content = parse_json_object_from_request(request)
 
@@ -202,7 +213,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
 
         logger.info("Got %r query", query_type)
 
-        result = yield self.registry.on_query(query_type, args)
+        result = await self.registry.on_query(query_type, args)
 
         return 200, result
 
@@ -213,7 +224,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
 
     Request format:
 
-        POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
+        POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id
 
         {}
     """
@@ -234,10 +245,41 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
         """
         return {}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, room_id):
-        yield self.store.clean_room_for_join(room_id)
+    async def _handle_request(self, request, room_id):
+        await self.store.clean_room_for_join(room_id)
+
+        return 200, {}
+
+
+class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
+    """Called to clean up any data in DB for a given room, ready for the
+    server to join the room.
+
+    Request format:
+
+        POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id
+
+        {
+            "room_version": "1",
+        }
+    """
+
+    NAME = "store_room_on_invite"
+    PATH_ARGS = ("room_id",)
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.store = hs.get_datastore()
+
+    @staticmethod
+    def _serialize_payload(room_id, room_version):
+        return {"room_version": room_version.identifier}
 
+    async def _handle_request(self, request, room_id):
+        content = parse_json_object_from_request(request)
+        room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
+        await self.store.maybe_store_room_on_invite(room_id, room_version)
         return 200, {}
 
 
@@ -246,3 +288,4 @@ def register_servlets(hs, http_server):
     ReplicationFederationSendEduRestServlet(hs).register(http_server)
     ReplicationGetQueryRestServlet(hs).register(http_server)
     ReplicationCleanRoomRestServlet(hs).register(http_server)
+    ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)