summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_server.py50
-rw-r--r--synapse/federation/persistence.py4
-rw-r--r--synapse/federation/sender/transaction_manager.py9
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/federation/transport/server.py11
-rw-r--r--synapse/federation/units.py90
6 files changed, 74 insertions, 92 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 145b9161d9..0385aadefa 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -195,13 +195,17 @@ class FederationServer(FederationBase):
                 origin, room_id, versions, limit
             )
 
-            res = self._transaction_from_pdus(pdus).get_dict()
+            res = self._transaction_dict_from_pdus(pdus)
 
         return 200, res
 
     async def on_incoming_transaction(
-        self, origin: str, transaction_data: JsonDict
-    ) -> Tuple[int, Dict[str, Any]]:
+        self,
+        origin: str,
+        transaction_id: str,
+        destination: str,
+        transaction_data: JsonDict,
+    ) -> Tuple[int, JsonDict]:
         # If we receive a transaction we should make sure that kick off handling
         # any old events in the staging area.
         if not self._started_handling_of_staged_events:
@@ -212,8 +216,14 @@ class FederationServer(FederationBase):
         # accurate as possible.
         request_time = self._clock.time_msec()
 
-        transaction = Transaction(**transaction_data)
-        transaction_id = transaction.transaction_id  # type: ignore
+        transaction = Transaction(
+            transaction_id=transaction_id,
+            destination=destination,
+            origin=origin,
+            origin_server_ts=transaction_data.get("origin_server_ts"),  # type: ignore
+            pdus=transaction_data.get("pdus"),  # type: ignore
+            edus=transaction_data.get("edus"),
+        )
 
         if not transaction_id:
             raise Exception("Transaction missing transaction_id")
@@ -221,9 +231,7 @@ class FederationServer(FederationBase):
         logger.debug("[%s] Got transaction", transaction_id)
 
         # Reject malformed transactions early: reject if too many PDUs/EDUs
-        if len(transaction.pdus) > 50 or (  # type: ignore
-            hasattr(transaction, "edus") and len(transaction.edus) > 100  # type: ignore
-        ):
+        if len(transaction.pdus) > 50 or len(transaction.edus) > 100:
             logger.info("Transaction PDU or EDU count too large. Returning 400")
             return 400, {}
 
@@ -263,7 +271,7 @@ class FederationServer(FederationBase):
         # CRITICAL SECTION: the first thing we must do (before awaiting) is
         # add an entry to _active_transactions.
         assert origin not in self._active_transactions
-        self._active_transactions[origin] = transaction.transaction_id  # type: ignore
+        self._active_transactions[origin] = transaction.transaction_id
 
         try:
             result = await self._handle_incoming_transaction(
@@ -291,11 +299,11 @@ class FederationServer(FederationBase):
         if response:
             logger.debug(
                 "[%s] We've already responded to this request",
-                transaction.transaction_id,  # type: ignore
+                transaction.transaction_id,
             )
             return response
 
-        logger.debug("[%s] Transaction is new", transaction.transaction_id)  # type: ignore
+        logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
         # We process PDUs and EDUs in parallel. This is important as we don't
         # want to block things like to device messages from reaching clients
@@ -334,7 +342,7 @@ class FederationServer(FederationBase):
             report back to the sending server.
         """
 
-        received_pdus_counter.inc(len(transaction.pdus))  # type: ignore
+        received_pdus_counter.inc(len(transaction.pdus))
 
         origin_host, _ = parse_server_name(origin)
 
@@ -342,7 +350,7 @@ class FederationServer(FederationBase):
 
         newest_pdu_ts = 0
 
-        for p in transaction.pdus:  # type: ignore
+        for p in transaction.pdus:
             # FIXME (richardv): I don't think this works:
             #  https://github.com/matrix-org/synapse/issues/8429
             if "unsigned" in p:
@@ -436,10 +444,10 @@ class FederationServer(FederationBase):
 
         return pdu_results
 
-    async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+    async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None:
         """Process the EDUs in a received transaction."""
 
-        async def _process_edu(edu_dict):
+        async def _process_edu(edu_dict: JsonDict) -> None:
             received_edus_counter.inc()
 
             edu = Edu(
@@ -452,7 +460,7 @@ class FederationServer(FederationBase):
 
         await concurrently_execute(
             _process_edu,
-            getattr(transaction, "edus", []),
+            transaction.edus,
             TRANSACTION_CONCURRENCY_LIMIT,
         )
 
@@ -538,7 +546,7 @@ class FederationServer(FederationBase):
         pdu = await self.handler.get_persisted_pdu(origin, event_id)
 
         if pdu:
-            return 200, self._transaction_from_pdus([pdu]).get_dict()
+            return 200, self._transaction_dict_from_pdus([pdu])
         else:
             return 404, ""
 
@@ -879,18 +887,20 @@ class FederationServer(FederationBase):
         ts_now_ms = self._clock.time_msec()
         return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
 
-    def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
+    def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict:
         """Returns a new Transaction containing the given PDUs suitable for
         transmission.
         """
         time_now = self._clock.time_msec()
         pdus = [p.get_pdu_json(time_now) for p in pdu_list]
         return Transaction(
+            # Just need a dummy transaction ID and destination since it won't be used.
+            transaction_id="",
             origin=self.server_name,
             pdus=pdus,
             origin_server_ts=int(time_now),
-            destination=None,
-        )
+            destination="",
+        ).get_dict()
 
     async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
         """Process a PDU received in a federation /send/ transaction.
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 2f9c9bc2cd..4fead6ca29 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -45,7 +45,7 @@ class TransactionActions:
             `None` if we have not previously responded to this transaction or a
             2-tuple of `(int, dict)` representing the response code and response body.
         """
-        transaction_id = transaction.transaction_id  # type: ignore
+        transaction_id = transaction.transaction_id
         if not transaction_id:
             raise RuntimeError("Cannot persist a transaction with no transaction_id")
 
@@ -56,7 +56,7 @@ class TransactionActions:
         self, origin: str, transaction: Transaction, code: int, response: JsonDict
     ) -> None:
         """Persist how we responded to a transaction."""
-        transaction_id = transaction.transaction_id  # type: ignore
+        transaction_id = transaction.transaction_id
         if not transaction_id:
             raise RuntimeError("Cannot persist a transaction with no transaction_id")
 
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 72a635830b..dc555cca0b 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -27,6 +27,7 @@ from synapse.logging.opentracing import (
     tags,
     whitelisted_homeserver,
 )
+from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.metrics import measure_func
 
@@ -104,13 +105,13 @@ class TransactionManager:
                 len(edus),
             )
 
-            transaction = Transaction.create_new(
+            transaction = Transaction(
                 origin_server_ts=int(self.clock.time_msec()),
                 transaction_id=txn_id,
                 origin=self._server_name,
                 destination=destination,
-                pdus=pdus,
-                edus=edus,
+                pdus=[p.get_pdu_json() for p in pdus],
+                edus=[edu.get_dict() for edu in edus],
             )
 
             self._next_txn_id += 1
@@ -131,7 +132,7 @@ class TransactionManager:
             # FIXME (richardv): I also believe it no longer works. We (now?) store
             #  "age_ts" in "unsigned" rather than at the top level. See
             #  https://github.com/matrix-org/synapse/issues/8429.
-            def json_data_cb():
+            def json_data_cb() -> JsonDict:
                 data = transaction.get_dict()
                 now = int(self.clock.time_msec())
                 if "pdus" in data:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 6a8d3ad4fe..90a7c16b62 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -143,7 +143,7 @@ class TransportLayerClient:
         """Sends the given Transaction to its destination
 
         Args:
-            transaction (Transaction)
+            transaction
 
         Returns:
             Succeeds when we get a 2xx HTTP response. The result
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 5e059d6e09..640f46fff6 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -450,21 +450,12 @@ class FederationSendServlet(BaseFederationServerServlet):
                 len(transaction_data.get("edus", [])),
             )
 
-            # We should ideally be getting this from the security layer.
-            # origin = body["origin"]
-
-            # Add some extra data to the transaction dict that isn't included
-            # in the request body.
-            transaction_data.update(
-                transaction_id=transaction_id, destination=self.server_name
-            )
-
         except Exception as e:
             logger.exception(e)
             return 400, {"error": "Invalid transaction"}
 
         code, response = await self.handler.on_incoming_transaction(
-            origin, transaction_data
+            origin, transaction_id, self.server_name, transaction_data
         )
 
         return code, response
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index c83a261918..b9b12fbea5 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -17,18 +17,17 @@ server protocol.
 """
 
 import logging
-from typing import Optional
+from typing import List, Optional
 
 import attr
 
 from synapse.types import JsonDict
-from synapse.util.jsonobject import JsonEncodedObject
 
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True)
-class Edu(JsonEncodedObject):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Edu:
     """An Edu represents a piece of data sent from one homeserver to another.
 
     In comparison to Pdus, Edus are not persisted for a long time on disk, are
@@ -36,10 +35,10 @@ class Edu(JsonEncodedObject):
     internal ID or previous references graph.
     """
 
-    edu_type = attr.ib(type=str)
-    content = attr.ib(type=dict)
-    origin = attr.ib(type=str)
-    destination = attr.ib(type=str)
+    edu_type: str
+    content: dict
+    origin: str
+    destination: str
 
     def get_dict(self) -> JsonDict:
         return {
@@ -55,14 +54,21 @@ class Edu(JsonEncodedObject):
             "destination": self.destination,
         }
 
-    def get_context(self):
+    def get_context(self) -> str:
         return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
 
-    def strip_context(self):
+    def strip_context(self) -> None:
         getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
 
 
-class Transaction(JsonEncodedObject):
+def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]:
+    if edus is None:
+        return []
+    return edus
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Transaction:
     """A transaction is a list of Pdus and Edus to be sent to a remote home
     server with some extra metadata.
 
@@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject):
 
     """
 
-    valid_keys = [
-        "transaction_id",
-        "origin",
-        "destination",
-        "origin_server_ts",
-        "previous_ids",
-        "pdus",
-        "edus",
-    ]
-
-    internal_keys = ["transaction_id", "destination"]
-
-    required_keys = [
-        "transaction_id",
-        "origin",
-        "destination",
-        "origin_server_ts",
-        "pdus",
-    ]
-
-    def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
-        """If we include a list of pdus then we decode then as PDU's
-        automatically.
-        """
-
-        # If there's no EDUs then remove the arg
-        if "edus" in kwargs and not kwargs["edus"]:
-            del kwargs["edus"]
-
-        super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
-
-    @staticmethod
-    def create_new(pdus, **kwargs):
-        """Used to create a new transaction. Will auto fill out
-        transaction_id and origin_server_ts keys.
-        """
-        if "origin_server_ts" not in kwargs:
-            raise KeyError("Require 'origin_server_ts' to construct a Transaction")
-        if "transaction_id" not in kwargs:
-            raise KeyError("Require 'transaction_id' to construct a Transaction")
-
-        kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
-
-        return Transaction(**kwargs)
+    # Required keys.
+    transaction_id: str
+    origin: str
+    destination: str
+    origin_server_ts: int
+    pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
+    edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
+
+    def get_dict(self) -> JsonDict:
+        """A JSON-ready dictionary of valid keys which aren't internal."""
+        result = {
+            "origin": self.origin,
+            "origin_server_ts": self.origin_server_ts,
+            "pdus": self.pdus,
+        }
+        if self.edus:
+            result["edus"] = self.edus
+        return result