From 1de26b346796ec8d6b51b4395017f8107f640c47 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 6 Aug 2021 09:39:59 -0400 Subject: Convert Transaction and Edu object to attrs (#10542) Instead of wrapping the JSON into an object, this creates concrete instances for Transaction and Edu. This allows for improved type hints and simplified code. --- synapse/federation/federation_server.py | 50 +++++++------ synapse/federation/persistence.py | 4 +- synapse/federation/sender/transaction_manager.py | 9 +-- synapse/federation/transport/client.py | 2 +- synapse/federation/transport/server.py | 11 +-- synapse/federation/units.py | 90 +++++++++--------------- 6 files changed, 74 insertions(+), 92 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 145b9161d9..0385aadefa 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -195,13 +195,17 @@ class FederationServer(FederationBase): origin, room_id, versions, limit ) - res = self._transaction_from_pdus(pdus).get_dict() + res = self._transaction_dict_from_pdus(pdus) return 200, res async def on_incoming_transaction( - self, origin: str, transaction_data: JsonDict - ) -> Tuple[int, Dict[str, Any]]: + self, + origin: str, + transaction_id: str, + destination: str, + transaction_data: JsonDict, + ) -> Tuple[int, JsonDict]: # If we receive a transaction we should make sure that kick off handling # any old events in the staging area. if not self._started_handling_of_staged_events: @@ -212,8 +216,14 @@ class FederationServer(FederationBase): # accurate as possible. request_time = self._clock.time_msec() - transaction = Transaction(**transaction_data) - transaction_id = transaction.transaction_id # type: ignore + transaction = Transaction( + transaction_id=transaction_id, + destination=destination, + origin=origin, + origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore + pdus=transaction_data.get("pdus"), # type: ignore + edus=transaction_data.get("edus"), + ) if not transaction_id: raise Exception("Transaction missing transaction_id") @@ -221,9 +231,7 @@ class FederationServer(FederationBase): logger.debug("[%s] Got transaction", transaction_id) # Reject malformed transactions early: reject if too many PDUs/EDUs - if len(transaction.pdus) > 50 or ( # type: ignore - hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore - ): + if len(transaction.pdus) > 50 or len(transaction.edus) > 100: logger.info("Transaction PDU or EDU count too large. Returning 400") return 400, {} @@ -263,7 +271,7 @@ class FederationServer(FederationBase): # CRITICAL SECTION: the first thing we must do (before awaiting) is # add an entry to _active_transactions. assert origin not in self._active_transactions - self._active_transactions[origin] = transaction.transaction_id # type: ignore + self._active_transactions[origin] = transaction.transaction_id try: result = await self._handle_incoming_transaction( @@ -291,11 +299,11 @@ class FederationServer(FederationBase): if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id, # type: ignore + transaction.transaction_id, ) return response - logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore + logger.debug("[%s] Transaction is new", transaction.transaction_id) # We process PDUs and EDUs in parallel. This is important as we don't # want to block things like to device messages from reaching clients @@ -334,7 +342,7 @@ class FederationServer(FederationBase): report back to the sending server. """ - received_pdus_counter.inc(len(transaction.pdus)) # type: ignore + received_pdus_counter.inc(len(transaction.pdus)) origin_host, _ = parse_server_name(origin) @@ -342,7 +350,7 @@ class FederationServer(FederationBase): newest_pdu_ts = 0 - for p in transaction.pdus: # type: ignore + for p in transaction.pdus: # FIXME (richardv): I don't think this works: # https://github.com/matrix-org/synapse/issues/8429 if "unsigned" in p: @@ -436,10 +444,10 @@ class FederationServer(FederationBase): return pdu_results - async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): + async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None: """Process the EDUs in a received transaction.""" - async def _process_edu(edu_dict): + async def _process_edu(edu_dict: JsonDict) -> None: received_edus_counter.inc() edu = Edu( @@ -452,7 +460,7 @@ class FederationServer(FederationBase): await concurrently_execute( _process_edu, - getattr(transaction, "edus", []), + transaction.edus, TRANSACTION_CONCURRENCY_LIMIT, ) @@ -538,7 +546,7 @@ class FederationServer(FederationBase): pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: - return 200, self._transaction_from_pdus([pdu]).get_dict() + return 200, self._transaction_dict_from_pdus([pdu]) else: return 404, "" @@ -879,18 +887,20 @@ class FederationServer(FederationBase): ts_now_ms = self._clock.time_msec() return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) - def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: + def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict: """Returns a new Transaction containing the given PDUs suitable for transmission. """ time_now = self._clock.time_msec() pdus = [p.get_pdu_json(time_now) for p in pdu_list] return Transaction( + # Just need a dummy transaction ID and destination since it won't be used. + transaction_id="", origin=self.server_name, pdus=pdus, origin_server_ts=int(time_now), - destination=None, - ) + destination="", + ).get_dict() async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: """Process a PDU received in a federation /send/ transaction. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 2f9c9bc2cd..4fead6ca29 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -45,7 +45,7 @@ class TransactionActions: `None` if we have not previously responded to this transaction or a 2-tuple of `(int, dict)` representing the response code and response body. """ - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") @@ -56,7 +56,7 @@ class TransactionActions: self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: """Persist how we responded to a transaction.""" - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 72a635830b..dc555cca0b 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -27,6 +27,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.metrics import measure_func @@ -104,13 +105,13 @@ class TransactionManager: len(edus), ) - transaction = Transaction.create_new( + transaction = Transaction( origin_server_ts=int(self.clock.time_msec()), transaction_id=txn_id, origin=self._server_name, destination=destination, - pdus=pdus, - edus=edus, + pdus=[p.get_pdu_json() for p in pdus], + edus=[edu.get_dict() for edu in edus], ) self._next_txn_id += 1 @@ -131,7 +132,7 @@ class TransactionManager: # FIXME (richardv): I also believe it no longer works. We (now?) store # "age_ts" in "unsigned" rather than at the top level. See # https://github.com/matrix-org/synapse/issues/8429. - def json_data_cb(): + def json_data_cb() -> JsonDict: data = transaction.get_dict() now = int(self.clock.time_msec()) if "pdus" in data: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 6a8d3ad4fe..90a7c16b62 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -143,7 +143,7 @@ class TransportLayerClient: """Sends the given Transaction to its destination Args: - transaction (Transaction) + transaction Returns: Succeeds when we get a 2xx HTTP response. The result diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 5e059d6e09..640f46fff6 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -450,21 +450,12 @@ class FederationSendServlet(BaseFederationServerServlet): len(transaction_data.get("edus", [])), ) - # We should ideally be getting this from the security layer. - # origin = body["origin"] - - # Add some extra data to the transaction dict that isn't included - # in the request body. - transaction_data.update( - transaction_id=transaction_id, destination=self.server_name - ) - except Exception as e: logger.exception(e) return 400, {"error": "Invalid transaction"} code, response = await self.handler.on_incoming_transaction( - origin, transaction_data + origin, transaction_id, self.server_name, transaction_data ) return code, response diff --git a/synapse/federation/units.py b/synapse/federation/units.py index c83a261918..b9b12fbea5 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -17,18 +17,17 @@ server protocol. """ import logging -from typing import Optional +from typing import List, Optional import attr from synapse.types import JsonDict -from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) -@attr.s(slots=True) -class Edu(JsonEncodedObject): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Edu: """An Edu represents a piece of data sent from one homeserver to another. In comparison to Pdus, Edus are not persisted for a long time on disk, are @@ -36,10 +35,10 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - edu_type = attr.ib(type=str) - content = attr.ib(type=dict) - origin = attr.ib(type=str) - destination = attr.ib(type=str) + edu_type: str + content: dict + origin: str + destination: str def get_dict(self) -> JsonDict: return { @@ -55,14 +54,21 @@ class Edu(JsonEncodedObject): "destination": self.destination, } - def get_context(self): + def get_context(self) -> str: return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") - def strip_context(self): + def strip_context(self) -> None: getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}" -class Transaction(JsonEncodedObject): +def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]: + if edus is None: + return [] + return edus + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Transaction: """A transaction is a list of Pdus and Edus to be sent to a remote home server with some extra metadata. @@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject): """ - valid_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "previous_ids", - "pdus", - "edus", - ] - - internal_keys = ["transaction_id", "destination"] - - required_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "pdus", - ] - - def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs): - """If we include a list of pdus then we decode then as PDU's - automatically. - """ - - # If there's no EDUs then remove the arg - if "edus" in kwargs and not kwargs["edus"]: - del kwargs["edus"] - - super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs) - - @staticmethod - def create_new(pdus, **kwargs): - """Used to create a new transaction. Will auto fill out - transaction_id and origin_server_ts keys. - """ - if "origin_server_ts" not in kwargs: - raise KeyError("Require 'origin_server_ts' to construct a Transaction") - if "transaction_id" not in kwargs: - raise KeyError("Require 'transaction_id' to construct a Transaction") - - kwargs["pdus"] = [p.get_pdu_json() for p in pdus] - - return Transaction(**kwargs) + # Required keys. + transaction_id: str + origin: str + destination: str + origin_server_ts: int + pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + + def get_dict(self) -> JsonDict: + """A JSON-ready dictionary of valid keys which aren't internal.""" + result = { + "origin": self.origin, + "origin_server_ts": self.origin_server_ts, + "pdus": self.pdus, + } + if self.edus: + result["edus"] = self.edus + return result -- cgit 1.4.1