summary refs log tree commit diff
path: root/synapse/federation/units.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-08-06 09:39:59 -0400
committerGitHub <noreply@github.com>2021-08-06 09:39:59 -0400
commit1de26b346796ec8d6b51b4395017f8107f640c47 (patch)
treea3b5c2d5ce6d179992a24468ffb71b51461973a2 /synapse/federation/units.py
parentFix exceptions in logs when failing to get remote room list (#10541) (diff)
downloadsynapse-1de26b346796ec8d6b51b4395017f8107f640c47.tar.xz
Convert Transaction and Edu object to attrs (#10542)
Instead of wrapping the JSON into an object, this creates concrete
instances for Transaction and Edu. This allows for improved type
hints and simplified code.
Diffstat (limited to 'synapse/federation/units.py')
-rw-r--r--synapse/federation/units.py90
1 files changed, 35 insertions, 55 deletions
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index c83a261918..b9b12fbea5 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -17,18 +17,17 @@ server protocol.
 """
 
 import logging
-from typing import Optional
+from typing import List, Optional
 
 import attr
 
 from synapse.types import JsonDict
-from synapse.util.jsonobject import JsonEncodedObject
 
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True)
-class Edu(JsonEncodedObject):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Edu:
     """An Edu represents a piece of data sent from one homeserver to another.
 
     In comparison to Pdus, Edus are not persisted for a long time on disk, are
@@ -36,10 +35,10 @@ class Edu(JsonEncodedObject):
     internal ID or previous references graph.
     """
 
-    edu_type = attr.ib(type=str)
-    content = attr.ib(type=dict)
-    origin = attr.ib(type=str)
-    destination = attr.ib(type=str)
+    edu_type: str
+    content: dict
+    origin: str
+    destination: str
 
     def get_dict(self) -> JsonDict:
         return {
@@ -55,14 +54,21 @@ class Edu(JsonEncodedObject):
             "destination": self.destination,
         }
 
-    def get_context(self):
+    def get_context(self) -> str:
         return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
 
-    def strip_context(self):
+    def strip_context(self) -> None:
         getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
 
 
-class Transaction(JsonEncodedObject):
+def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]:
+    if edus is None:
+        return []
+    return edus
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Transaction:
     """A transaction is a list of Pdus and Edus to be sent to a remote home
     server with some extra metadata.
 
@@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject):
 
     """
 
-    valid_keys = [
-        "transaction_id",
-        "origin",
-        "destination",
-        "origin_server_ts",
-        "previous_ids",
-        "pdus",
-        "edus",
-    ]
-
-    internal_keys = ["transaction_id", "destination"]
-
-    required_keys = [
-        "transaction_id",
-        "origin",
-        "destination",
-        "origin_server_ts",
-        "pdus",
-    ]
-
-    def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
-        """If we include a list of pdus then we decode then as PDU's
-        automatically.
-        """
-
-        # If there's no EDUs then remove the arg
-        if "edus" in kwargs and not kwargs["edus"]:
-            del kwargs["edus"]
-
-        super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
-
-    @staticmethod
-    def create_new(pdus, **kwargs):
-        """Used to create a new transaction. Will auto fill out
-        transaction_id and origin_server_ts keys.
-        """
-        if "origin_server_ts" not in kwargs:
-            raise KeyError("Require 'origin_server_ts' to construct a Transaction")
-        if "transaction_id" not in kwargs:
-            raise KeyError("Require 'transaction_id' to construct a Transaction")
-
-        kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
-
-        return Transaction(**kwargs)
+    # Required keys.
+    transaction_id: str
+    origin: str
+    destination: str
+    origin_server_ts: int
+    pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
+    edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
+
+    def get_dict(self) -> JsonDict:
+        """A JSON-ready dictionary of valid keys which aren't internal."""
+        result = {
+            "origin": self.origin,
+            "origin_server_ts": self.origin_server_ts,
+            "pdus": self.pdus,
+        }
+        if self.edus:
+            result["edus"] = self.edus
+        return result