diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index c83a261918..b9b12fbea5 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -17,18 +17,17 @@ server protocol.
"""
import logging
-from typing import Optional
+from typing import List, Optional
import attr
from synapse.types import JsonDict
-from synapse.util.jsonobject import JsonEncodedObject
logger = logging.getLogger(__name__)
-@attr.s(slots=True)
-class Edu(JsonEncodedObject):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Edu:
"""An Edu represents a piece of data sent from one homeserver to another.
In comparison to Pdus, Edus are not persisted for a long time on disk, are
@@ -36,10 +35,10 @@ class Edu(JsonEncodedObject):
internal ID or previous references graph.
"""
- edu_type = attr.ib(type=str)
- content = attr.ib(type=dict)
- origin = attr.ib(type=str)
- destination = attr.ib(type=str)
+ edu_type: str
+ content: dict
+ origin: str
+ destination: str
def get_dict(self) -> JsonDict:
return {
@@ -55,14 +54,21 @@ class Edu(JsonEncodedObject):
"destination": self.destination,
}
- def get_context(self):
+ def get_context(self) -> str:
return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
- def strip_context(self):
+ def strip_context(self) -> None:
getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
-class Transaction(JsonEncodedObject):
+def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]:
+ if edus is None:
+ return []
+ return edus
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Transaction:
"""A transaction is a list of Pdus and Edus to be sent to a remote home
server with some extra metadata.
@@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject):
"""
- valid_keys = [
- "transaction_id",
- "origin",
- "destination",
- "origin_server_ts",
- "previous_ids",
- "pdus",
- "edus",
- ]
-
- internal_keys = ["transaction_id", "destination"]
-
- required_keys = [
- "transaction_id",
- "origin",
- "destination",
- "origin_server_ts",
- "pdus",
- ]
-
- def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
- """If we include a list of pdus then we decode then as PDU's
- automatically.
- """
-
- # If there's no EDUs then remove the arg
- if "edus" in kwargs and not kwargs["edus"]:
- del kwargs["edus"]
-
- super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
-
- @staticmethod
- def create_new(pdus, **kwargs):
- """Used to create a new transaction. Will auto fill out
- transaction_id and origin_server_ts keys.
- """
- if "origin_server_ts" not in kwargs:
- raise KeyError("Require 'origin_server_ts' to construct a Transaction")
- if "transaction_id" not in kwargs:
- raise KeyError("Require 'transaction_id' to construct a Transaction")
-
- kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
-
- return Transaction(**kwargs)
+ # Required keys.
+ transaction_id: str
+ origin: str
+ destination: str
+ origin_server_ts: int
+ pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
+ edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
+
+ def get_dict(self) -> JsonDict:
+ """A JSON-ready dictionary of valid keys which aren't internal."""
+ result = {
+ "origin": self.origin,
+ "origin_server_ts": self.origin_server_ts,
+ "pdus": self.pdus,
+ }
+ if self.edus:
+ result["edus"] = self.edus
+ return result
|