summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2023-07-11 15:28:48 -0400
committerPatrick Cloke <patrickc@matrix.org>2023-07-17 11:05:43 -0400
commit23cd415b9e81e247c18495e0b4621316b3a9186a (patch)
treea776eed4274f7aedcbfb0b1c2757dd8b641cef19
parentHandle LPDU content hash. (diff)
downloadsynapse-23cd415b9e81e247c18495e0b4621316b3a9186a.tar.xz
Implement new event and backfill endpoints.
-rw-r--r--synapse/federation/federation_client.py87
-rw-r--r--synapse/federation/federation_server.py35
-rw-r--r--synapse/federation/transport/client.py62
-rw-r--r--synapse/federation/transport/server/federation.py89
4 files changed, 209 insertions, 64 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e5359ca558..f60ef8c16c 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -325,9 +325,21 @@ class FederationClient(FederationBase):
         if not extremities:
             return None
 
-        transaction_data = await self.transport_layer.backfill(
-            dest, room_id, extremities, limit
-        )
+        try:
+            # Note that this only returns pdus now, but this is close enough to a transaction.
+            transaction_data = await self.transport_layer.backfill_unstable(
+                dest, room_id, extremities, limit
+            )
+        except HttpResponseException as e:
+            # If an error is received that is due to an unrecognised endpoint,
+            # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
+            # and raise.
+            if not is_unknown_endpoint(e):
+                raise
+
+            transaction_data = await self.transport_layer.backfill(
+                dest, room_id, extremities, limit
+            )
 
         logger.debug("backfill transaction_data=%r", transaction_data)
 
@@ -373,45 +385,58 @@ class FederationClient(FederationBase):
         Raises:
             SynapseError, NotRetryingDestination, FederationDeniedError
         """
-        transaction_data = await self.transport_layer.get_event(
-            destination, event_id, timeout=timeout
-        )
+        try:
+            # Note that this only returns pdus now, but this is close enough to a transaction.
+            pdu_json = await self.transport_layer.get_event_unstable(
+                destination, event_id, timeout=timeout
+            )
+
+            pdu = event_from_pdu_json(pdu_json, room_version)
+
+        except HttpResponseException as e:
+            # If an error is received that is due to an unrecognised endpoint,
+            # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
+            # and raise.
+            if not is_unknown_endpoint(e):
+                raise
+
+            transaction_data = await self.transport_layer.get_event(
+                destination, event_id, timeout=timeout
+            )
+
+            pdu_list: List[EventBase] = [
+                event_from_pdu_json(p, room_version) for p in transaction_data["pdus"]
+            ]
+
+            if pdu_list and pdu_list[0]:
+                pdu = pdu_list[0]
+            else:
+                return None
 
         logger.debug(
             "get_pdu_from_destination_raw: retrieved event id %s from %s: %r",
             event_id,
             destination,
-            transaction_data,
+            pdu,
         )
 
-        pdu_list: List[EventBase] = [
-            event_from_pdu_json(p, room_version) for p in transaction_data["pdus"]
-        ]
-
-        if pdu_list and pdu_list[0]:
-            pdu = pdu_list[0]
-
-            # Check signatures are correct.
-            try:
-
-                async def _record_failure_callback(
-                    event: EventBase, cause: str
-                ) -> None:
-                    await self.store.record_event_failed_pull_attempt(
-                        event.room_id, event.event_id, cause
-                    )
+        # Check signatures are correct.
+        try:
 
-                signed_pdu = await self._check_sigs_and_hash(
-                    room_version, pdu, _record_failure_callback
+            async def _record_failure_callback(event: EventBase, cause: str) -> None:
+                await self.store.record_event_failed_pull_attempt(
+                    event.room_id, event.event_id, cause
                 )
-            except InvalidEventSignatureError as e:
-                errmsg = f"event id {pdu.event_id}: {e}"
-                logger.warning("%s", errmsg)
-                raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
-            return signed_pdu
+            signed_pdu = await self._check_sigs_and_hash(
+                room_version, pdu, _record_failure_callback
+            )
+        except InvalidEventSignatureError as e:
+            errmsg = f"event id {pdu.event_id}: {e}"
+            logger.warning("%s", errmsg)
+            raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
-        return None
+        return signed_pdu
 
     @trace
     @tag_args
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e26796b408..aff750d0f5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -26,7 +26,6 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
-    Union,
 )
 
 from matrix_common.regex import glob_to_regex
@@ -213,19 +212,15 @@ class FederationServer(FederationBase):
 
     async def on_backfill_request(
         self, origin: str, room_id: str, versions: List[str], limit: int
-    ) -> Tuple[int, Dict[str, Any]]:
+    ) -> List[EventBase]:
         async with self._server_linearizer.queue((origin, room_id)):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)
 
-            pdus = await self.handler.on_backfill_request(
+            return await self.handler.on_backfill_request(
                 origin, room_id, versions, limit
             )
 
-            res = self._transaction_dict_from_pdus(pdus)
-
-        return 200, res
-
     async def on_timestamp_to_event_request(
         self, origin: str, room_id: str, timestamp: int, direction: Direction
     ) -> Tuple[int, Dict[str, Any]]:
@@ -621,15 +616,8 @@ class FederationServer(FederationBase):
             "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
         }
 
-    async def on_pdu_request(
-        self, origin: str, event_id: str
-    ) -> Tuple[int, Union[JsonDict, str]]:
-        pdu = await self.handler.get_persisted_pdu(origin, event_id)
-
-        if pdu:
-            return 200, self._transaction_dict_from_pdus([pdu])
-        else:
-            return 404, ""
+    async def on_pdu_request(self, origin: str, event_id: str) -> Optional[EventBase]:
+        return await self.handler.get_persisted_pdu(origin, event_id)
 
     async def on_query_request(
         self, query_type: str, args: Dict[str, str]
@@ -1078,21 +1066,6 @@ class FederationServer(FederationBase):
         ts_now_ms = self._clock.time_msec()
         return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
 
-    def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict:
-        """Returns a new Transaction containing the given PDUs suitable for
-        transmission.
-        """
-        time_now = self._clock.time_msec()
-        pdus = [p.get_pdu_json(time_now) for p in pdu_list]
-        return Transaction(
-            # Just need a dummy transaction ID and destination since it won't be used.
-            transaction_id="",
-            origin=self.server_name,
-            pdus=pdus,
-            origin_server_ts=int(time_now),
-            destination="",
-        ).get_dict()
-
     async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
         """Process a PDU received in a federation /send/ transaction.
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 0b17f713ea..012ce4710b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -134,6 +134,31 @@ class TransportLayerClient:
             destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
         )
 
+    async def get_event_unstable(
+        self, destination: str, event_id: str, timeout: Optional[int] = None
+    ) -> JsonDict:
+        """Requests the pdu with give id and origin from the given server.
+
+        Args:
+            destination: The host name of the remote homeserver we want
+                to get the state from.
+            event_id: The id of the event being requested.
+            timeout: How long to try (in ms) the destination for before
+                giving up. None indicates no timeout.
+
+        Returns:
+            Results in a dict received from the remote homeserver.
+        """
+        logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
+
+        path = f"/_matrix/federation/unstable/org.matrix.i-d.ralston-mimi-linearized-matrix.02/event/{event_id}"
+        result = await self.client.get_json(
+            destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
+        )
+        # Note that this has many callers, convert the result into the v1 response
+        # (i.e. a transaction).
+        return {"pdus": [result]}
+
     async def backfill(
         self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
     ) -> Optional[Union[JsonDict, list]]:
@@ -171,6 +196,43 @@ class TransportLayerClient:
             destination, path=path, args=args, try_trailing_slash_on_400=True
         )
 
+    async def backfill_unstable(
+        self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
+    ) -> Optional[Union[JsonDict, list]]:
+        """Requests `limit` previous PDUs in a given context before list of
+        PDUs.
+
+        Args:
+            destination
+            room_id
+            event_tuples:
+                Must be a Collection that is falsy when empty.
+                (Iterable is not enough here!)
+            limit
+
+        Returns:
+            Results in a dict received from the remote homeserver.
+        """
+        logger.debug(
+            "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
+            destination,
+            room_id,
+            event_tuples,
+            str(limit),
+        )
+
+        if not event_tuples:
+            # TODO: raise?
+            return None
+
+        path = f"/_matrix/federation/unstable/org.matrix.i-d.ralston-mimi-linearized-matrix.02/backfill/{room_id}"
+
+        args = {"v": event_tuples, "limit": [str(limit)]}
+
+        return await self.client.get_json(
+            destination, path=path, args=args, try_trailing_slash_on_400=True
+        )
+
     async def timestamp_to_event(
         self, destination: str, room_id: str, timestamp: int, direction: Direction
     ) -> Union[JsonDict, List]:
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 3248953b48..d16375aecb 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -35,6 +35,7 @@ from synapse.federation.transport.server._base import (
     Authenticator,
     BaseFederationServlet,
 )
+from synapse.federation.units import Transaction
 from synapse.http.servlet import (
     parse_boolean_from_args,
     parse_integer_from_args,
@@ -67,6 +68,7 @@ class BaseFederationServerServlet(BaseFederationServlet):
     ):
         super().__init__(hs, authenticator, ratelimiter, server_name)
         self.handler = hs.get_federation_server()
+        self._clock = hs.get_clock()
 
 
 class FederationSendServlet(BaseFederationServerServlet):
@@ -150,7 +152,44 @@ class FederationEventServlet(BaseFederationServerServlet):
         query: Dict[bytes, List[bytes]],
         event_id: str,
     ) -> Tuple[int, Union[JsonDict, str]]:
-        return await self.handler.on_pdu_request(origin, event_id)
+        event = await self.handler.on_pdu_request(origin, event_id)
+
+        if event:
+            # Returns a new Transaction containing the given PDUs suitable for transmission.
+            time_now = self._clock.time_msec()
+            pdus = [event.get_pdu_json(time_now)]
+            return (
+                200,
+                Transaction(
+                    # Just need a dummy transaction ID and destination since it won't be used.
+                    transaction_id="",
+                    origin=self.server_name,
+                    pdus=pdus,
+                    origin_server_ts=int(time_now),
+                    destination="",
+                ).get_dict(),
+            )
+
+        return 404, ""
+
+
+class FederationUnstableEventServlet(BaseFederationServerServlet):
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.i-d.ralston-mimi-linearized-matrix.02"
+    PATH = "/event/(?P<event_id>[^/]*)/?"
+    CATEGORY = "Federation requests"
+
+    # This is when someone asks for a data item for a given server data_id pair.
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        event_id: str,
+    ) -> Tuple[int, Union[JsonDict, str]]:
+        event = await self.handler.on_pdu_request(origin, event_id)
+        if event:
+            return 200, event.get_dict()
+        return 404, ""
 
 
 class FederationStateV1Servlet(BaseFederationServerServlet):
@@ -207,7 +246,50 @@ class FederationBackfillServlet(BaseFederationServerServlet):
         if not limit:
             return 400, {"error": "Did not include limit param"}
 
-        return await self.handler.on_backfill_request(origin, room_id, versions, limit)
+        pdu_list = await self.handler.on_backfill_request(
+            origin, room_id, versions, limit
+        )
+
+        # Returns a new Transaction containing the given PDUs suitable for transmission.
+        time_now = self._clock.time_msec()
+        pdus = [p.get_pdu_json(time_now) for p in pdu_list]
+        return (
+            200,
+            Transaction(
+                # Just need a dummy transaction ID and destination since it won't be used.
+                transaction_id="",
+                origin=self.server_name,
+                pdus=pdus,
+                origin_server_ts=int(time_now),
+                destination="",
+            ).get_dict(),
+        )
+
+
+class FederationUnstableBackfillServlet(BaseFederationServerServlet):
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.i-d.ralston-mimi-linearized-matrix.02"
+    PATH = "/backfill/(?P<room_id>[^/]*)/?"
+    CATEGORY = "Federation requests"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        versions = [x.decode("ascii") for x in query[b"v"]]
+        # TODO(LM) Only a single version is allowed for Linearized Matrix.
+        limit = parse_integer_from_args(query, "limit", None)
+
+        if not limit:
+            return 400, {"error": "Did not include limit param"}
+
+        pdu_list = await self.handler.on_backfill_request(
+            origin, room_id, versions, limit
+        )
+
+        return 200, {"pdus": [p.get_pdu_json() for p in pdu_list]}
 
 
 class FederationTimestampLookupServlet(BaseFederationServerServlet):
@@ -811,4 +893,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationV1SendKnockServlet,
     FederationMakeKnockServlet,
     FederationAccountStatusServlet,
+    # TODO(LM) Linearized Matrix additions.
+    FederationUnstableEventServlet,
+    FederationUnstableBackfillServlet,
 )