diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 145b9161d9..0385aadefa 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -195,13 +195,17 @@ class FederationServer(FederationBase):
origin, room_id, versions, limit
)
- res = self._transaction_from_pdus(pdus).get_dict()
+ res = self._transaction_dict_from_pdus(pdus)
return 200, res
async def on_incoming_transaction(
- self, origin: str, transaction_data: JsonDict
- ) -> Tuple[int, Dict[str, Any]]:
+ self,
+ origin: str,
+ transaction_id: str,
+ destination: str,
+ transaction_data: JsonDict,
+ ) -> Tuple[int, JsonDict]:
# If we receive a transaction we should make sure that kick off handling
# any old events in the staging area.
if not self._started_handling_of_staged_events:
@@ -212,8 +216,14 @@ class FederationServer(FederationBase):
# accurate as possible.
request_time = self._clock.time_msec()
- transaction = Transaction(**transaction_data)
- transaction_id = transaction.transaction_id # type: ignore
+ transaction = Transaction(
+ transaction_id=transaction_id,
+ destination=destination,
+ origin=origin,
+ origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore
+ pdus=transaction_data.get("pdus"), # type: ignore
+ edus=transaction_data.get("edus"),
+ )
if not transaction_id:
raise Exception("Transaction missing transaction_id")
@@ -221,9 +231,7 @@ class FederationServer(FederationBase):
logger.debug("[%s] Got transaction", transaction_id)
# Reject malformed transactions early: reject if too many PDUs/EDUs
- if len(transaction.pdus) > 50 or ( # type: ignore
- hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
- ):
+ if len(transaction.pdus) > 50 or len(transaction.edus) > 100:
logger.info("Transaction PDU or EDU count too large. Returning 400")
return 400, {}
@@ -263,7 +271,7 @@ class FederationServer(FederationBase):
# CRITICAL SECTION: the first thing we must do (before awaiting) is
# add an entry to _active_transactions.
assert origin not in self._active_transactions
- self._active_transactions[origin] = transaction.transaction_id # type: ignore
+ self._active_transactions[origin] = transaction.transaction_id
try:
result = await self._handle_incoming_transaction(
@@ -291,11 +299,11 @@ class FederationServer(FederationBase):
if response:
logger.debug(
"[%s] We've already responded to this request",
- transaction.transaction_id, # type: ignore
+ transaction.transaction_id,
)
return response
- logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
+ logger.debug("[%s] Transaction is new", transaction.transaction_id)
# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
@@ -334,7 +342,7 @@ class FederationServer(FederationBase):
report back to the sending server.
"""
- received_pdus_counter.inc(len(transaction.pdus)) # type: ignore
+ received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(origin)
@@ -342,7 +350,7 @@ class FederationServer(FederationBase):
newest_pdu_ts = 0
- for p in transaction.pdus: # type: ignore
+ for p in transaction.pdus:
# FIXME (richardv): I don't think this works:
# https://github.com/matrix-org/synapse/issues/8429
if "unsigned" in p:
@@ -436,10 +444,10 @@ class FederationServer(FederationBase):
return pdu_results
- async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+ async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None:
"""Process the EDUs in a received transaction."""
- async def _process_edu(edu_dict):
+ async def _process_edu(edu_dict: JsonDict) -> None:
received_edus_counter.inc()
edu = Edu(
@@ -452,7 +460,7 @@ class FederationServer(FederationBase):
await concurrently_execute(
_process_edu,
- getattr(transaction, "edus", []),
+ transaction.edus,
TRANSACTION_CONCURRENCY_LIMIT,
)
@@ -538,7 +546,7 @@ class FederationServer(FederationBase):
pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu:
- return 200, self._transaction_from_pdus([pdu]).get_dict()
+ return 200, self._transaction_dict_from_pdus([pdu])
else:
return 404, ""
@@ -879,18 +887,20 @@ class FederationServer(FederationBase):
ts_now_ms = self._clock.time_msec()
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
- def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
+ def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict:
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
time_now = self._clock.time_msec()
pdus = [p.get_pdu_json(time_now) for p in pdu_list]
return Transaction(
+ # Just need a dummy transaction ID and destination since it won't be used.
+ transaction_id="",
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(time_now),
- destination=None,
- )
+ destination="",
+ ).get_dict()
async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
"""Process a PDU received in a federation /send/ transaction.
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 2f9c9bc2cd..4fead6ca29 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -45,7 +45,7 @@ class TransactionActions:
`None` if we have not previously responded to this transaction or a
2-tuple of `(int, dict)` representing the response code and response body.
"""
- transaction_id = transaction.transaction_id # type: ignore
+ transaction_id = transaction.transaction_id
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")
@@ -56,7 +56,7 @@ class TransactionActions:
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
"""Persist how we responded to a transaction."""
- transaction_id = transaction.transaction_id # type: ignore
+ transaction_id = transaction.transaction_id
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 72a635830b..dc555cca0b 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -27,6 +27,7 @@ from synapse.logging.opentracing import (
tags,
whitelisted_homeserver,
)
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.metrics import measure_func
@@ -104,13 +105,13 @@ class TransactionManager:
len(edus),
)
- transaction = Transaction.create_new(
+ transaction = Transaction(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self._server_name,
destination=destination,
- pdus=pdus,
- edus=edus,
+ pdus=[p.get_pdu_json() for p in pdus],
+ edus=[edu.get_dict() for edu in edus],
)
self._next_txn_id += 1
@@ -131,7 +132,7 @@ class TransactionManager:
# FIXME (richardv): I also believe it no longer works. We (now?) store
# "age_ts" in "unsigned" rather than at the top level. See
# https://github.com/matrix-org/synapse/issues/8429.
- def json_data_cb():
+ def json_data_cb() -> JsonDict:
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 6a8d3ad4fe..90a7c16b62 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -143,7 +143,7 @@ class TransportLayerClient:
"""Sends the given Transaction to its destination
Args:
- transaction (Transaction)
+ transaction
Returns:
Succeeds when we get a 2xx HTTP response. The result
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 5e059d6e09..640f46fff6 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -450,21 +450,12 @@ class FederationSendServlet(BaseFederationServerServlet):
len(transaction_data.get("edus", [])),
)
- # We should ideally be getting this from the security layer.
- # origin = body["origin"]
-
- # Add some extra data to the transaction dict that isn't included
- # in the request body.
- transaction_data.update(
- transaction_id=transaction_id, destination=self.server_name
- )
-
except Exception as e:
logger.exception(e)
return 400, {"error": "Invalid transaction"}
code, response = await self.handler.on_incoming_transaction(
- origin, transaction_data
+ origin, transaction_id, self.server_name, transaction_data
)
return code, response
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index c83a261918..b9b12fbea5 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -17,18 +17,17 @@ server protocol.
"""
import logging
-from typing import Optional
+from typing import List, Optional
import attr
from synapse.types import JsonDict
-from synapse.util.jsonobject import JsonEncodedObject
logger = logging.getLogger(__name__)
-@attr.s(slots=True)
-class Edu(JsonEncodedObject):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Edu:
"""An Edu represents a piece of data sent from one homeserver to another.
In comparison to Pdus, Edus are not persisted for a long time on disk, are
@@ -36,10 +35,10 @@ class Edu(JsonEncodedObject):
internal ID or previous references graph.
"""
- edu_type = attr.ib(type=str)
- content = attr.ib(type=dict)
- origin = attr.ib(type=str)
- destination = attr.ib(type=str)
+ edu_type: str
+ content: dict
+ origin: str
+ destination: str
def get_dict(self) -> JsonDict:
return {
@@ -55,14 +54,21 @@ class Edu(JsonEncodedObject):
"destination": self.destination,
}
- def get_context(self):
+ def get_context(self) -> str:
return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
- def strip_context(self):
+ def strip_context(self) -> None:
getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
-class Transaction(JsonEncodedObject):
+def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]:
+ if edus is None:
+ return []
+ return edus
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Transaction:
"""A transaction is a list of Pdus and Edus to be sent to a remote home
server with some extra metadata.
@@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject):
"""
- valid_keys = [
- "transaction_id",
- "origin",
- "destination",
- "origin_server_ts",
- "previous_ids",
- "pdus",
- "edus",
- ]
-
- internal_keys = ["transaction_id", "destination"]
-
- required_keys = [
- "transaction_id",
- "origin",
- "destination",
- "origin_server_ts",
- "pdus",
- ]
-
- def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
- """If we include a list of pdus then we decode then as PDU's
- automatically.
- """
-
- # If there's no EDUs then remove the arg
- if "edus" in kwargs and not kwargs["edus"]:
- del kwargs["edus"]
-
- super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
-
- @staticmethod
- def create_new(pdus, **kwargs):
- """Used to create a new transaction. Will auto fill out
- transaction_id and origin_server_ts keys.
- """
- if "origin_server_ts" not in kwargs:
- raise KeyError("Require 'origin_server_ts' to construct a Transaction")
- if "transaction_id" not in kwargs:
- raise KeyError("Require 'transaction_id' to construct a Transaction")
-
- kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
-
- return Transaction(**kwargs)
+ # Required keys.
+ transaction_id: str
+ origin: str
+ destination: str
+ origin_server_ts: int
+ pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
+ edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
+
+ def get_dict(self) -> JsonDict:
+ """A JSON-ready dictionary of valid keys which aren't internal."""
+ result = {
+ "origin": self.origin,
+ "origin_server_ts": self.origin_server_ts,
+ "pdus": self.pdus,
+ }
+ if self.edus:
+ result["edus"] = self.edus
+ return result
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
deleted file mode 100644
index abc12f0837..0000000000
--- a/synapse/util/jsonobject.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-class JsonEncodedObject:
- """A common base class for defining protocol units that are represented
- as JSON.
-
- Attributes:
- unrecognized_keys (dict): A dict containing all the key/value pairs we
- don't recognize.
- """
-
- valid_keys = [] # keys we will store
- """A list of strings that represent keys we know about
- and can handle. If we have values for these keys they will be
- included in the `dictionary` instance variable.
- """
-
- internal_keys = [] # keys to ignore while building dict
- """A list of strings that should *not* be encoded into JSON.
- """
-
- required_keys = []
- """A list of strings that we require to exist. If they are not given upon
- construction it raises an exception.
- """
-
- def __init__(self, **kwargs):
- """Takes the dict of `kwargs` and loads all keys that are *valid*
- (i.e., are included in the `valid_keys` list) into the dictionary`
- instance variable.
-
- Any keys that aren't recognized are added to the `unrecognized_keys`
- attribute.
-
- Args:
- **kwargs: Attributes associated with this protocol unit.
- """
- for required_key in self.required_keys:
- if required_key not in kwargs:
- raise RuntimeError("Key %s is required" % required_key)
-
- self.unrecognized_keys = {} # Keys we were given not listed as valid
- for k, v in kwargs.items():
- if k in self.valid_keys or k in self.internal_keys:
- self.__dict__[k] = v
- else:
- self.unrecognized_keys[k] = v
-
- def get_dict(self):
- """Converts this protocol unit into a :py:class:`dict`, ready to be
- encoded as JSON.
-
- The keys it encodes are: `valid_keys` - `internal_keys`
-
- Returns
- dict
- """
- d = {
- k: _encode(v)
- for (k, v) in self.__dict__.items()
- if k in self.valid_keys and k not in self.internal_keys
- }
- d.update(self.unrecognized_keys)
- return d
-
- def get_internal_dict(self):
- d = {
- k: _encode(v, internal=True)
- for (k, v) in self.__dict__.items()
- if k in self.valid_keys
- }
- d.update(self.unrecognized_keys)
- return d
-
- def __str__(self):
- return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
-
-
-def _encode(obj, internal=False):
- if type(obj) is list:
- return [_encode(o, internal=internal) for o in obj]
-
- if isinstance(obj, JsonEncodedObject):
- if internal:
- return obj.get_internal_dict()
- else:
- return obj.get_dict()
-
- return obj
|