summary refs log tree commit diff
path: root/synapse/federation/federation_server.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-08-06 09:39:59 -0400
committerGitHub <noreply@github.com>2021-08-06 09:39:59 -0400
commit1de26b346796ec8d6b51b4395017f8107f640c47 (patch)
treea3b5c2d5ce6d179992a24468ffb71b51461973a2 /synapse/federation/federation_server.py
parentFix exceptions in logs when failing to get remote room list (#10541) (diff)
downloadsynapse-1de26b346796ec8d6b51b4395017f8107f640c47.tar.xz
Convert Transaction and Edu object to attrs (#10542)
Instead of wrapping the JSON into an object, this creates concrete
instances for Transaction and Edu. This allows for improved type
hints and simplified code.
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r--synapse/federation/federation_server.py50
1 files changed, 30 insertions, 20 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 145b9161d9..0385aadefa 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -195,13 +195,17 @@ class FederationServer(FederationBase):
                 origin, room_id, versions, limit
             )
 
-            res = self._transaction_from_pdus(pdus).get_dict()
+            res = self._transaction_dict_from_pdus(pdus)
 
         return 200, res
 
     async def on_incoming_transaction(
-        self, origin: str, transaction_data: JsonDict
-    ) -> Tuple[int, Dict[str, Any]]:
+        self,
+        origin: str,
+        transaction_id: str,
+        destination: str,
+        transaction_data: JsonDict,
+    ) -> Tuple[int, JsonDict]:
         # If we receive a transaction we should make sure that kick off handling
         # any old events in the staging area.
         if not self._started_handling_of_staged_events:
@@ -212,8 +216,14 @@ class FederationServer(FederationBase):
         # accurate as possible.
         request_time = self._clock.time_msec()
 
-        transaction = Transaction(**transaction_data)
-        transaction_id = transaction.transaction_id  # type: ignore
+        transaction = Transaction(
+            transaction_id=transaction_id,
+            destination=destination,
+            origin=origin,
+            origin_server_ts=transaction_data.get("origin_server_ts"),  # type: ignore
+            pdus=transaction_data.get("pdus"),  # type: ignore
+            edus=transaction_data.get("edus"),
+        )
 
         if not transaction_id:
             raise Exception("Transaction missing transaction_id")
@@ -221,9 +231,7 @@ class FederationServer(FederationBase):
         logger.debug("[%s] Got transaction", transaction_id)
 
         # Reject malformed transactions early: reject if too many PDUs/EDUs
-        if len(transaction.pdus) > 50 or (  # type: ignore
-            hasattr(transaction, "edus") and len(transaction.edus) > 100  # type: ignore
-        ):
+        if len(transaction.pdus) > 50 or len(transaction.edus) > 100:
             logger.info("Transaction PDU or EDU count too large. Returning 400")
             return 400, {}
 
@@ -263,7 +271,7 @@ class FederationServer(FederationBase):
         # CRITICAL SECTION: the first thing we must do (before awaiting) is
         # add an entry to _active_transactions.
         assert origin not in self._active_transactions
-        self._active_transactions[origin] = transaction.transaction_id  # type: ignore
+        self._active_transactions[origin] = transaction.transaction_id
 
         try:
             result = await self._handle_incoming_transaction(
@@ -291,11 +299,11 @@ class FederationServer(FederationBase):
         if response:
             logger.debug(
                 "[%s] We've already responded to this request",
-                transaction.transaction_id,  # type: ignore
+                transaction.transaction_id,
             )
             return response
 
-        logger.debug("[%s] Transaction is new", transaction.transaction_id)  # type: ignore
+        logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
         # We process PDUs and EDUs in parallel. This is important as we don't
         # want to block things like to device messages from reaching clients
@@ -334,7 +342,7 @@ class FederationServer(FederationBase):
             report back to the sending server.
         """
 
-        received_pdus_counter.inc(len(transaction.pdus))  # type: ignore
+        received_pdus_counter.inc(len(transaction.pdus))
 
         origin_host, _ = parse_server_name(origin)
 
@@ -342,7 +350,7 @@ class FederationServer(FederationBase):
 
         newest_pdu_ts = 0
 
-        for p in transaction.pdus:  # type: ignore
+        for p in transaction.pdus:
             # FIXME (richardv): I don't think this works:
             #  https://github.com/matrix-org/synapse/issues/8429
             if "unsigned" in p:
@@ -436,10 +444,10 @@ class FederationServer(FederationBase):
 
         return pdu_results
 
-    async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+    async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None:
         """Process the EDUs in a received transaction."""
 
-        async def _process_edu(edu_dict):
+        async def _process_edu(edu_dict: JsonDict) -> None:
             received_edus_counter.inc()
 
             edu = Edu(
@@ -452,7 +460,7 @@ class FederationServer(FederationBase):
 
         await concurrently_execute(
             _process_edu,
-            getattr(transaction, "edus", []),
+            transaction.edus,
             TRANSACTION_CONCURRENCY_LIMIT,
         )
 
@@ -538,7 +546,7 @@ class FederationServer(FederationBase):
         pdu = await self.handler.get_persisted_pdu(origin, event_id)
 
         if pdu:
-            return 200, self._transaction_from_pdus([pdu]).get_dict()
+            return 200, self._transaction_dict_from_pdus([pdu])
         else:
             return 404, ""
 
@@ -879,18 +887,20 @@ class FederationServer(FederationBase):
         ts_now_ms = self._clock.time_msec()
         return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
 
-    def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
+    def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict:
         """Returns a new Transaction containing the given PDUs suitable for
         transmission.
         """
         time_now = self._clock.time_msec()
         pdus = [p.get_pdu_json(time_now) for p in pdu_list]
         return Transaction(
+            # Just need a dummy transaction ID and destination since it won't be used.
+            transaction_id="",
             origin=self.server_name,
             pdus=pdus,
             origin_server_ts=int(time_now),
-            destination=None,
-        )
+            destination="",
+        ).get_dict()
 
     async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
         """Process a PDU received in a federation /send/ transaction.