diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
new file mode 100644
index 0000000000..b4d95ed5ac
--- /dev/null
+++ b/synapse/federation/__init__.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This package includes all the federation specific logic.
+"""
+
+from .replication import ReplicationLayer
+from .transport import TransportLayer
+
+
+def initialize_http_replication(homeserver):
+ transport = TransportLayer(
+ homeserver.hostname,
+ server=homeserver.get_http_server(),
+ client=homeserver.get_http_client()
+ )
+
+ return ReplicationLayer(homeserver, transport)
diff --git a/synapse/federation/handler.py b/synapse/federation/handler.py
new file mode 100644
index 0000000000..31e8470b33
--- /dev/null
+++ b/synapse/federation/handler.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from .pdu_codec import PduCodec
+
+from synapse.api.errors import AuthError
+from synapse.util.logutils import log_function
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class FederationEventHandler(object):
+ """ Responsible for:
+ a) handling received Pdus before handing them on as Events to the rest
+ of the home server (including auth and state conflict resoultion)
+ b) converting events that were produced by local clients that may need
+ to be sent to remote home servers.
+ """
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.replication_layer = hs.get_replication_layer()
+ self.state_handler = hs.get_state_handler()
+ # self.auth_handler = gs.get_auth_handler()
+ self.event_handler = hs.get_handlers().federation_handler
+ self.server_name = hs.hostname
+
+ self.lock_manager = hs.get_room_lock_manager()
+
+ self.replication_layer.set_handler(self)
+
+ self.pdu_codec = PduCodec(hs)
+
+ @log_function
+ @defer.inlineCallbacks
+ def handle_new_event(self, event):
+ """ Takes in an event from the client to server side, that has already
+ been authed and handled by the state module, and sends it to any
+ remote home servers that may be interested.
+
+ Args:
+ event
+
+ Returns:
+ Deferred: Resolved when it has successfully been queued for
+ processing.
+ """
+ yield self._fill_out_prev_events(event)
+
+ pdu = self.pdu_codec.pdu_from_event(event)
+
+ if not hasattr(pdu, "destinations") or not pdu.destinations:
+ pdu.destinations = []
+
+ yield self.replication_layer.send_pdu(pdu)
+
+ @log_function
+ @defer.inlineCallbacks
+ def backfill(self, room_id, limit):
+ # TODO: Work out which destinations to ask for pagination
+ # self.replication_layer.paginate(dest, room_id, limit)
+ pass
+
+ @log_function
+ def get_state_for_room(self, destination, room_id):
+ return self.replication_layer.get_state_for_context(
+ destination, room_id
+ )
+
+ @log_function
+ @defer.inlineCallbacks
+ def on_receive_pdu(self, pdu):
+ """ Called by the ReplicationLayer when we have a new pdu. We need to
+ do auth checks and put it throught the StateHandler.
+ """
+ event = self.pdu_codec.event_from_pdu(pdu)
+
+ try:
+ with (yield self.lock_manager.lock(pdu.context)):
+ if event.is_state:
+ is_new_state = yield self.state_handler.handle_new_state(
+ pdu
+ )
+ if not is_new_state:
+ return
+ else:
+ is_new_state = False
+
+ yield self.event_handler.on_receive(event, is_new_state)
+
+ except AuthError:
+ # TODO: Implement something in federation that allows us to
+ # respond to PDU.
+ raise
+
+ return
+
+ @defer.inlineCallbacks
+ def _on_new_state(self, pdu, new_state_event):
+ # TODO: Do any store stuff here. Notifiy C2S about this new
+ # state.
+
+ yield self.store.update_current_state(
+ pdu_id=pdu.pdu_id,
+ origin=pdu.origin,
+ context=pdu.context,
+ pdu_type=pdu.pdu_type,
+ state_key=pdu.state_key
+ )
+
+ yield self.event_handler.on_receive(new_state_event)
+
+ @defer.inlineCallbacks
+ def _fill_out_prev_events(self, event):
+ if hasattr(event, "prev_events"):
+ return
+
+ results = yield self.store.get_latest_pdus_in_context(
+ event.room_id
+ )
+
+ es = [
+ "%s@%s" % (p_id, origin) for p_id, origin, _ in results
+ ]
+
+ event.prev_events = [e for e in es if e != event.event_id]
+
+ if results:
+ event.depth = max([int(v) for _, _, v in results]) + 1
+ else:
+ event.depth = 0
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
new file mode 100644
index 0000000000..9155930e47
--- /dev/null
+++ b/synapse/federation/pdu_codec.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .units import Pdu
+
+import copy
+
+
+def decode_event_id(event_id, server_name):
+ parts = event_id.split("@")
+ if len(parts) < 2:
+ return (event_id, server_name)
+ else:
+ return (parts[0], "".join(parts[1:]))
+
+
+def encode_event_id(pdu_id, origin):
+ return "%s@%s" % (pdu_id, origin)
+
+
+class PduCodec(object):
+
+ def __init__(self, hs):
+ self.server_name = hs.hostname
+ self.event_factory = hs.get_event_factory()
+ self.clock = hs.get_clock()
+
+ def event_from_pdu(self, pdu):
+ kwargs = {}
+
+ kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
+ kwargs["room_id"] = pdu.context
+ kwargs["etype"] = pdu.pdu_type
+ kwargs["prev_events"] = [
+ encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
+ ]
+
+ if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
+ kwargs["prev_state"] = encode_event_id(
+ pdu.prev_state_id, pdu.prev_state_origin
+ )
+
+ kwargs.update({
+ k: v
+ for k, v in pdu.get_full_dict().items()
+ if k not in [
+ "pdu_id",
+ "context",
+ "pdu_type",
+ "prev_pdus",
+ "prev_state_id",
+ "prev_state_origin",
+ ]
+ })
+
+ return self.event_factory.create_event(**kwargs)
+
+ def pdu_from_event(self, event):
+ d = event.get_full_dict()
+
+ d["pdu_id"], d["origin"] = decode_event_id(
+ event.event_id, self.server_name
+ )
+ d["context"] = event.room_id
+ d["pdu_type"] = event.type
+
+ if hasattr(event, "prev_events"):
+ d["prev_pdus"] = [
+ decode_event_id(e, self.server_name)
+ for e in event.prev_events
+ ]
+
+ if hasattr(event, "prev_state"):
+ d["prev_state_id"], d["prev_state_origin"] = (
+ decode_event_id(event.prev_state, self.server_name)
+ )
+
+ if hasattr(event, "state_key"):
+ d["is_state"] = True
+
+ kwargs = copy.deepcopy(event.unrecognized_keys)
+ kwargs.update({
+ k: v for k, v in d.items()
+ if k not in ["event_id", "room_id", "type", "prev_events"]
+ })
+
+ if "ts" not in kwargs:
+ kwargs["ts"] = int(self.clock.time_msec())
+
+ return Pdu(**kwargs)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
new file mode 100644
index 0000000000..ad4111c683
--- /dev/null
+++ b/synapse/federation/persistence.py
@@ -0,0 +1,240 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This module contains all the persistence actions done by the federation
+package.
+
+These actions are mostly only used by the :py:mod:`.replication` module.
+"""
+
+from twisted.internet import defer
+
+from .units import Pdu
+
+from synapse.util.logutils import log_function
+
+import copy
+import json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class PduActions(object):
+ """ Defines persistence actions that relate to handling PDUs.
+ """
+
+ def __init__(self, datastore):
+ self.store = datastore
+
+ @log_function
+ def persist_received(self, pdu):
+ """ Persists the given `Pdu` that was received from a remote home
+ server.
+
+ Returns:
+ Deferred
+ """
+ return self._persist(pdu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def persist_outgoing(self, pdu):
+ """ Persists the given `Pdu` that this home server created.
+
+ Returns:
+ Deferred
+ """
+ ret = yield self._persist(pdu)
+
+ defer.returnValue(ret)
+
+ @log_function
+ def mark_as_processed(self, pdu):
+ """ Persist the fact that we have fully processed the given `Pdu`
+
+ Returns:
+ Deferred
+ """
+ return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
+
+ @defer.inlineCallbacks
+ @log_function
+ def populate_previous_pdus(self, pdu):
+ """ Given an outgoing `Pdu` fill out its `prev_ids` key with the `Pdu`s
+ that we have received.
+
+ Returns:
+ Deferred
+ """
+ results = yield self.store.get_latest_pdus_in_context(pdu.context)
+
+ pdu.prev_pdus = [(p_id, origin) for p_id, origin, _ in results]
+
+ vs = [int(v) for _, _, v in results]
+ if vs:
+ pdu.depth = max(vs) + 1
+ else:
+ pdu.depth = 0
+
+ @defer.inlineCallbacks
+ @log_function
+ def after_transaction(self, transaction_id, destination, origin):
+ """ Returns all `Pdu`s that we sent to the given remote home server
+ after a given transaction id.
+
+ Returns:
+ Deferred: Results in a list of `Pdu`s
+ """
+ results = yield self.store.get_pdus_after_transaction(
+ transaction_id,
+ destination
+ )
+
+ defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_all_pdus_from_context(self, context):
+ results = yield self.store.get_all_pdus_from_context(context)
+ defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
+
+ @defer.inlineCallbacks
+ @log_function
+ def paginate(self, context, pdu_list, limit):
+ """ For a given list of PDU id and origins return the proceeding
+ `limit` `Pdu`s in the given `context`.
+
+ Returns:
+ Deferred: Results in a list of `Pdu`s.
+ """
+ results = yield self.store.get_pagination(
+ context, pdu_list, limit
+ )
+
+ defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
+
+ @log_function
+ def is_new(self, pdu):
+ """ When we receive a `Pdu` from a remote home server, we want to
+ figure out whether it is `new`, i.e. it is not some historic PDU that
+ we haven't seen simply because we haven't paginated back that far.
+
+ Returns:
+ Deferred: Results in a `bool`
+ """
+ return self.store.is_pdu_new(
+ pdu_id=pdu.pdu_id,
+ origin=pdu.origin,
+ context=pdu.context,
+ depth=pdu.depth
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _persist(self, pdu):
+ kwargs = copy.copy(pdu.__dict__)
+ unrec_keys = copy.copy(pdu.unrecognized_keys)
+ del kwargs["content"]
+ kwargs["content_json"] = json.dumps(pdu.content)
+ kwargs["unrecognized_keys"] = json.dumps(unrec_keys)
+
+ logger.debug("Persisting: %s", repr(kwargs))
+
+ if pdu.is_state:
+ ret = yield self.store.persist_state(**kwargs)
+ else:
+ ret = yield self.store.persist_pdu(**kwargs)
+
+ yield self.store.update_min_depth_for_context(
+ pdu.context, pdu.depth
+ )
+
+ defer.returnValue(ret)
+
+
+class TransactionActions(object):
+ """ Defines persistence actions that relate to handling Transactions.
+ """
+
+ def __init__(self, datastore):
+ self.store = datastore
+
+ @log_function
+ def have_responded(self, transaction):
+ """ Have we already responded to a transaction with the same id and
+ origin?
+
+ Returns:
+ Deferred: Results in `None` if we have not previously responded to
+ this transaction or a 2-tuple of `(int, dict)` representing the
+ response code and response body.
+ """
+ if not transaction.transaction_id:
+ raise RuntimeError("Cannot persist a transaction with no "
+ "transaction_id")
+
+ return self.store.get_received_txn_response(
+ transaction.transaction_id, transaction.origin
+ )
+
+ @log_function
+ def set_response(self, transaction, code, response):
+ """ Persist how we responded to a transaction.
+
+ Returns:
+ Deferred
+ """
+ if not transaction.transaction_id:
+ raise RuntimeError("Cannot persist a transaction with no "
+ "transaction_id")
+
+ return self.store.set_received_txn_response(
+ transaction.transaction_id,
+ transaction.origin,
+ code,
+ json.dumps(response)
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def prepare_to_send(self, transaction):
+ """ Persists the `Transaction` we are about to send and works out the
+ correct value for the `prev_ids` key.
+
+ Returns:
+ Deferred
+ """
+ transaction.prev_ids = yield self.store.prep_send_transaction(
+ transaction.transaction_id,
+ transaction.destination,
+ transaction.ts,
+ [(p["pdu_id"], p["origin"]) for p in transaction.pdus]
+ )
+
+ @log_function
+ def delivered(self, transaction, response_code, response_dict):
+ """ Marks the given `Transaction` as having been successfully
+ delivered to the remote homeserver, and what the response was.
+
+ Returns:
+ Deferred
+ """
+ return self.store.delivered_txn(
+ transaction.transaction_id,
+ transaction.destination,
+ response_code,
+ json.dumps(response_dict)
+ )
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
new file mode 100644
index 0000000000..0f5b974291
--- /dev/null
+++ b/synapse/federation/replication.py
@@ -0,0 +1,582 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This layer is responsible for replicating with remote home servers using
+a given transport.
+"""
+
+from twisted.internet import defer
+
+from .units import Transaction, Pdu, Edu
+
+from .persistence import PduActions, TransactionActions
+
+from synapse.util.logutils import log_function
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationLayer(object):
+ """This layer is responsible for replicating with remote home servers over
+ the given transport. I.e., does the sending and receiving of PDUs to
+ remote home servers.
+
+ The layer communicates with the rest of the server via a registered
+ ReplicationHandler.
+
+ In more detail, the layer:
+ * Receives incoming data and processes it into transactions and pdus.
+ * Fetches any PDUs it thinks it might have missed.
+ * Keeps the current state for contexts up to date by applying the
+ suitable conflict resolution.
+ * Sends outgoing pdus wrapped in transactions.
+ * Fills out the references to previous pdus/transactions appropriately
+ for outgoing data.
+ """
+
+ def __init__(self, hs, transport_layer):
+ self.server_name = hs.hostname
+
+ self.transport_layer = transport_layer
+ self.transport_layer.register_received_handler(self)
+ self.transport_layer.register_request_handler(self)
+
+ self.store = hs.get_datastore()
+ self.pdu_actions = PduActions(self.store)
+ self.transaction_actions = TransactionActions(self.store)
+
+ self._transaction_queue = _TransactionQueue(
+ hs, self.transaction_actions, transport_layer
+ )
+
+ self.handler = None
+ self.edu_handlers = {}
+
+ self._order = 0
+
+ self._clock = hs.get_clock()
+
+ def set_handler(self, handler):
+ """Sets the handler that the replication layer will use to communicate
+ receipt of new PDUs from other home servers. The required methods are
+ documented on :py:class:`.ReplicationHandler`.
+ """
+ self.handler = handler
+
+ def register_edu_handler(self, edu_type, handler):
+ if edu_type in self.edu_handlers:
+ raise KeyError("Already have an EDU handler for %s" % (edu_type))
+
+ self.edu_handlers[edu_type] = handler
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_pdu(self, pdu):
+ """Informs the replication layer about a new PDU generated within the
+ home server that should be transmitted to others.
+
+ This will fill out various attributes on the PDU object, e.g. the
+ `prev_pdus` key.
+
+ *Note:* The home server should always call `send_pdu` even if it knows
+ that it does not need to be replicated to other home servers. This is
+ in case e.g. someone else joins via a remote home server and then
+ paginates.
+
+ TODO: Figure out when we should actually resolve the deferred.
+
+ Args:
+ pdu (Pdu): The new Pdu.
+
+ Returns:
+ Deferred: Completes when we have successfully processed the PDU
+ and replicated it to any interested remote home servers.
+ """
+ order = self._order
+ self._order += 1
+
+ logger.debug("[%s] Persisting PDU", pdu.pdu_id)
+
+ #yield self.pdu_actions.populate_previous_pdus(pdu)
+
+ # Save *before* trying to send
+ yield self.pdu_actions.persist_outgoing(pdu)
+
+ logger.debug("[%s] Persisted PDU", pdu.pdu_id)
+ logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+
+ # TODO, add errback, etc.
+ self._transaction_queue.enqueue_pdu(pdu, order)
+
+ logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+
+ @log_function
+ def send_edu(self, destination, edu_type, content):
+ edu = Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type=edu_type,
+ content=content,
+ )
+
+ # TODO, add errback, etc.
+ self._transaction_queue.enqueue_edu(edu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def paginate(self, dest, context, limit):
+ """Requests some more historic PDUs for the given context from the
+ given destination server.
+
+ Args:
+ dest (str): The remote home server to ask.
+ context (str): The context to paginate back on.
+ limit (int): The maximum number of PDUs to return.
+
+ Returns:
+ Deferred: Results in the received PDUs.
+ """
+ extremities = yield self.store.get_oldest_pdus_in_context(context)
+
+ logger.debug("paginate extrem=%s", extremities)
+
+ # If there are no extremeties then we've (probably) reached the start.
+ if not extremities:
+ return
+
+ transaction_data = yield self.transport_layer.paginate(
+ dest, context, extremities, limit)
+
+ logger.debug("paginate transaction_data=%s", repr(transaction_data))
+
+ transaction = Transaction(**transaction_data)
+
+ pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
+ for pdu in pdus:
+ yield self._handle_new_pdu(pdu)
+
+ defer.returnValue(pdus)
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+ """Requests the PDU with given origin and ID from the remote home
+ server.
+
+ This will persist the PDU locally upon receipt.
+
+ Args:
+ destination (str): Which home server to query
+ pdu_origin (str): The home server that originally sent the pdu.
+ pdu_id (str)
+ outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
+ it's from an arbitary point in the context as opposed to part
+ of the current block of PDUs. Defaults to `False`
+
+ Returns:
+ Deferred: Results in the requested PDU.
+ """
+
+ transaction_data = yield self.transport_layer.get_pdu(
+ destination, pdu_origin, pdu_id)
+
+ transaction = Transaction(**transaction_data)
+
+ pdu_list = [Pdu(outlier=outlier, **p) for p in transaction.pdus]
+
+ pdu = None
+ if pdu_list:
+ pdu = pdu_list[0]
+ yield self._handle_new_pdu(pdu)
+
+ defer.returnValue(pdu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_state_for_context(self, destination, context):
+ """Requests all of the `current` state PDUs for a given context from
+ a remote home server.
+
+ Args:
+ destination (str): The remote homeserver to query for the state.
+ context (str): The context we're interested in.
+
+ Returns:
+ Deferred: Results in a list of PDUs.
+ """
+
+ transaction_data = yield self.transport_layer.get_context_state(
+ destination, context)
+
+ transaction = Transaction(**transaction_data)
+
+ pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
+ for pdu in pdus:
+ yield self._handle_new_pdu(pdu)
+
+ defer.returnValue(pdus)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_context_pdus_request(self, context):
+ pdus = yield self.pdu_actions.get_all_pdus_from_context(
+ context
+ )
+ defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_paginate_request(self, context, versions, limit):
+
+ pdus = yield self.pdu_actions.paginate(context, versions, limit)
+
+ defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_incoming_transaction(self, transaction_data):
+ transaction = Transaction(**transaction_data)
+
+ logger.debug("[%s] Got transaction", transaction.transaction_id)
+
+ response = yield self.transaction_actions.have_responded(transaction)
+
+ if response:
+ logger.debug("[%s] We've already responed to this request",
+ transaction.transaction_id)
+ defer.returnValue(response)
+ return
+
+ logger.debug("[%s] Transacition is new", transaction.transaction_id)
+
+ pdu_list = [Pdu(**p) for p in transaction.pdus]
+
+ dl = []
+ for pdu in pdu_list:
+ dl.append(self._handle_new_pdu(pdu))
+
+ if hasattr(transaction, "edus"):
+ for edu in [Edu(**x) for x in transaction.edus]:
+ self.received_edu(edu.origin, edu.edu_type, edu.content)
+
+ results = yield defer.DeferredList(dl)
+
+ ret = []
+ for r in results:
+ if r[0]:
+ ret.append({})
+ else:
+ logger.exception(r[1])
+ ret.append({"error": str(r[1])})
+
+ logger.debug("Returning: %s", str(ret))
+
+ yield self.transaction_actions.set_response(
+ transaction,
+ 200, response
+ )
+ defer.returnValue((200, response))
+
+ def received_edu(self, origin, edu_type, content):
+ if edu_type in self.edu_handlers:
+ self.edu_handlers[edu_type](origin, content)
+ else:
+ logger.warn("Received EDU of type %s with no handler", edu_type)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_context_state_request(self, context):
+ results = yield self.store.get_current_state_for_context(
+ context
+ )
+
+ logger.debug("Context returning %d results", len(results))
+
+ pdus = [Pdu.from_pdu_tuple(p) for p in results]
+ defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_pdu_request(self, pdu_origin, pdu_id):
+ pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+
+ if pdu:
+ defer.returnValue(
+ (200, self._transaction_from_pdus([pdu]).get_dict())
+ )
+ else:
+ defer.returnValue((404, ""))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_pull_request(self, origin, versions):
+ transaction_id = max([int(v) for v in versions])
+
+ response = yield self.pdu_actions.after_transaction(
+ transaction_id,
+ origin,
+ self.server_name
+ )
+
+ if not response:
+ response = []
+
+ defer.returnValue(
+ (200, self._transaction_from_pdus(response).get_dict())
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _get_persisted_pdu(self, pdu_id, pdu_origin):
+ """ Get a PDU from the database with given origin and id.
+
+ Returns:
+ Deferred: Results in a `Pdu`.
+ """
+ pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
+
+ defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+
+ def _transaction_from_pdus(self, pdu_list):
+ """Returns a new Transaction containing the given PDUs suitable for
+ transmission.
+ """
+ return Transaction(
+ pdus=[p.get_dict() for p in pdu_list],
+ origin=self.server_name,
+ ts=int(self._clock.time_msec()),
+ destination=None,
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _handle_new_pdu(self, pdu):
+ # We reprocess pdus when we have seen them only as outliers
+ existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+
+ if existing and (not existing.outlier or pdu.outlier):
+ logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+ defer.returnValue({})
+ return
+
+ # Get missing pdus if necessary.
+ is_new = yield self.pdu_actions.is_new(pdu)
+ if is_new and not pdu.outlier:
+ # We only paginate backwards to the min depth.
+ min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+
+ if min_depth and pdu.depth > min_depth:
+ for pdu_id, origin in pdu.prev_pdus:
+ exists = yield self._get_persisted_pdu(pdu_id, origin)
+
+ if not exists:
+ logger.debug("Requesting pdu %s %s", pdu_id, origin)
+
+ try:
+ yield self.get_pdu(
+ pdu.origin,
+ pdu_id=pdu_id,
+ pdu_origin=origin
+ )
+ logger.debug("Processed pdu %s %s", pdu_id, origin)
+ except:
+ # TODO(erikj): Do some more intelligent retries.
+ logger.exception("Failed to get PDU")
+
+ # Persist the Pdu, but don't mark it as processed yet.
+ yield self.pdu_actions.persist_received(pdu)
+
+ ret = yield self.handler.on_receive_pdu(pdu)
+
+ yield self.pdu_actions.mark_as_processed(pdu)
+
+ defer.returnValue(ret)
+
+ def __str__(self):
+ return "<ReplicationLayer(%s)>" % self.server_name
+
+
+class ReplicationHandler(object):
+ """This defines the methods that the :py:class:`.ReplicationLayer` will
+ use to communicate with the rest of the home server.
+ """
+ def on_receive_pdu(self, pdu):
+ raise NotImplementedError("on_receive_pdu")
+
+
+class _TransactionQueue(object):
+ """This class makes sure we only have one transaction in flight at
+ a time for a given destination.
+
+ It batches pending PDUs into single transactions.
+ """
+
+ def __init__(self, hs, transaction_actions, transport_layer):
+
+ self.server_name = hs.hostname
+ self.transaction_actions = transaction_actions
+ self.transport_layer = transport_layer
+
+ self._clock = hs.get_clock()
+
+ # Is a mapping from destinations -> deferreds. Used to keep track
+ # of which destinations have transactions in flight and when they are
+ # done
+ self.pending_transactions = {}
+
+ # Is a mapping from destination -> list of
+ # tuple(pending pdus, deferred, order)
+ self.pending_pdus_by_dest = {}
+ # destination -> list of tuple(edu, deferred)
+ self.pending_edus_by_dest = {}
+
+ # HACK to get unique tx id
+ self._next_txn_id = int(self._clock.time_msec())
+
+ @defer.inlineCallbacks
+ @log_function
+ def enqueue_pdu(self, pdu, order):
+ # We loop through all destinations to see whether we already have
+ # a transaction in progress. If we do, stick it in the pending_pdus
+ # table and we'll get back to it later.
+
+ destinations = [
+ d for d in pdu.destinations
+ if d != self.server_name
+ ]
+
+ logger.debug("Sending to: %s", str(destinations))
+
+ if not destinations:
+ return
+
+ deferreds = []
+
+ for destination in destinations:
+ deferred = defer.Deferred()
+ self.pending_pdus_by_dest.setdefault(destination, []).append(
+ (pdu, deferred, order)
+ )
+
+ self._attempt_new_transaction(destination)
+
+ deferreds.append(deferred)
+
+ yield defer.DeferredList(deferreds)
+
+ # NO inlineCallbacks
+ def enqueue_edu(self, edu):
+ destination = edu.destination
+
+ deferred = defer.Deferred()
+ self.pending_edus_by_dest.setdefault(destination, []).append(
+ (edu, deferred)
+ )
+
+ def eb(failure):
+ deferred.errback(failure)
+ self._attempt_new_transaction(destination).addErrback(eb)
+
+ return deferred
+
+ @defer.inlineCallbacks
+ @log_function
+ def _attempt_new_transaction(self, destination):
+ if destination in self.pending_transactions:
+ return
+
+ # list of (pending_pdu, deferred, order)
+ pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+ pending_edus = self.pending_edus_by_dest.pop(destination, [])
+
+ if not pending_pdus and not pending_edus:
+ return
+
+ logger.debug("TX [%s] Attempting new transaction", destination)
+
+ # Sort based on the order field
+ pending_pdus.sort(key=lambda t: t[2])
+
+ pdus = [x[0] for x in pending_pdus]
+ edus = [x[0] for x in pending_edus]
+ deferreds = [x[1] for x in pending_pdus + pending_edus]
+
+ try:
+ self.pending_transactions[destination] = 1
+
+ logger.debug("TX [%s] Persisting transaction...", destination)
+
+ transaction = Transaction.create_new(
+ ts=self._clock.time_msec(),
+ transaction_id=self._next_txn_id,
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ )
+
+ self._next_txn_id += 1
+
+ yield self.transaction_actions.prepare_to_send(transaction)
+
+ logger.debug("TX [%s] Persisted transaction", destination)
+ logger.debug("TX [%s] Sending transaction...", destination)
+
+ # Actually send the transaction
+ code, response = yield self.transport_layer.send_transaction(
+ transaction
+ )
+
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
+
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
+
+ logger.debug("TX [%s] Marked as delivered", destination)
+ logger.debug("TX [%s] Yielding to callbacks...", destination)
+
+ for deferred in deferreds:
+ if code == 200:
+ deferred.callback(None)
+ else:
+ deferred.errback(RuntimeError("Got status %d" % code))
+
+ # Ensures we don't continue until all callbacks on that
+ # deferred have fired
+ yield deferred
+
+ logger.debug("TX [%s] Yielded to callbacks", destination)
+
+ except Exception as e:
+ logger.error("TX Problem in _attempt_transaction")
+
+ # We capture this here as there as nothing actually listens
+ # for this finishing functions deferred.
+ logger.exception(e)
+
+ for deferred in deferreds:
+ deferred.errback(e)
+ yield deferred
+
+ finally:
+ # We want to be *very* sure we delete this after we stop processing
+ self.pending_transactions.pop(destination, None)
+
+ # Check to see if there is anything else to send.
+ self._attempt_new_transaction(destination)
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
new file mode 100644
index 0000000000..2136adf8d7
--- /dev/null
+++ b/synapse/federation/transport.py
@@ -0,0 +1,454 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The transport layer is responsible for both sending transactions to remote
+home servers and receiving a variety of requests from other home servers.
+
+Typically, this is done over HTTP (and all home servers are required to
+support HTTP), however individual pairings of servers may decide to communicate
+over a different (albeit still reliable) protocol.
+"""
+
+from twisted.internet import defer
+
+from synapse.util.logutils import log_function
+
+import logging
+import json
+import re
+
+
+logger = logging.getLogger(__name__)
+
+
+class TransportLayer(object):
+ """This is a basic implementation of the transport layer that translates
+ transactions and other requests to/from HTTP.
+
+ Attributes:
+ server_name (str): Local home server host
+
+ server (synapse.http.server.HttpServer): the http server to
+ register listeners on
+
+ client (synapse.http.client.HttpClient): the http client used to
+ send requests
+
+ request_handler (TransportRequestHandler): The handler to fire when we
+ receive requests for data.
+
+ received_handler (TransportReceivedHandler): The handler to fire when
+ we receive data.
+ """
+
+ def __init__(self, server_name, server, client):
+ """
+ Args:
+ server_name (str): Local home server host
+ server (synapse.protocol.http.HttpServer): the http server to
+ register listeners on
+ client (synapse.protocol.http.HttpClient): the http client used to
+ send requests
+ """
+ self.server_name = server_name
+ self.server = server
+ self.client = client
+ self.request_handler = None
+ self.received_handler = None
+
+ @log_function
+ def get_context_state(self, destination, context):
+ """ Requests all state for a given context (i.e. room) from the
+ given server.
+
+ Args:
+ destination (str): The host name of the remote home server we want
+ to get the state from.
+ context (str): The name of the context we want the state of
+
+ Returns:
+ Deferred: Results in a dict received from the remote homeserver.
+ """
+ logger.debug("get_context_state dest=%s, context=%s",
+ destination, context)
+
+ path = "/state/%s/" % context
+
+ return self._do_request_for_transaction(destination, path)
+
+ @log_function
+ def get_pdu(self, destination, pdu_origin, pdu_id):
+ """ Requests the pdu with give id and origin from the given server.
+
+ Args:
+ destination (str): The host name of the remote home server we want
+ to get the state from.
+ pdu_origin (str): The home server which created the PDU.
+ pdu_id (str): The id of the PDU being requested.
+
+ Returns:
+ Deferred: Results in a dict received from the remote homeserver.
+ """
+ logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
+ destination, pdu_origin, pdu_id)
+
+ path = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
+
+ return self._do_request_for_transaction(destination, path)
+
+ @log_function
+ def paginate(self, dest, context, pdu_tuples, limit):
+ """ Requests `limit` previous PDUs in a given context before list of
+ PDUs.
+
+ Args:
+ dest (str)
+ context (str)
+ pdu_tuples (list)
+ limt (int)
+
+ Returns:
+ Deferred: Results in a dict received from the remote homeserver.
+ """
+ logger.debug(
+ "paginate dest=%s, context=%s, pdu_tuples=%s, limit=%s",
+ dest, context, repr(pdu_tuples), str(limit)
+ )
+
+ if not pdu_tuples:
+ return
+
+ path = "/paginate/%s/" % context
+
+ args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
+ args["limit"] = limit
+
+ return self._do_request_for_transaction(
+ dest,
+ path,
+ args=args,
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_transaction(self, transaction):
+ """ Sends the given Transaction to it's destination
+
+ Args:
+ transaction (Transaction)
+
+ Returns:
+ Deferred: Results of the deferred is a tuple in the form of
+ (response_code, response_body) where the response_body is a
+ python dict decoded from json
+ """
+ logger.debug(
+ "send_data dest=%s, txid=%s",
+ transaction.destination, transaction.transaction_id
+ )
+
+ if transaction.destination == self.server_name:
+ raise RuntimeError("Transport layer cannot send to itself!")
+
+ data = transaction.get_dict()
+
+ code, response = yield self.client.put_json(
+ transaction.destination,
+ path="/send/%s/" % transaction.transaction_id,
+ data=data
+ )
+
+ logger.debug(
+ "send_data dest=%s, txid=%s, got response: %d",
+ transaction.destination, transaction.transaction_id, code
+ )
+
+ defer.returnValue((code, response))
+
+ @log_function
+ def register_received_handler(self, handler):
+ """ Register a handler that will be fired when we receive data.
+
+ Args:
+ handler (TransportReceivedHandler)
+ """
+ self.received_handler = handler
+
+ # This is when someone is trying to send us a bunch of data.
+ self.server.register_path(
+ "PUT",
+ re.compile("^/send/([^/]*)/$"),
+ self._on_send_request
+ )
+
+ @log_function
+ def register_request_handler(self, handler):
+ """ Register a handler that will be fired when we get asked for data.
+
+ Args:
+ handler (TransportRequestHandler)
+ """
+ self.request_handler = handler
+
+ # TODO(markjh): Namespace the federation URI paths
+
+ # This is for when someone asks us for everything since version X
+ self.server.register_path(
+ "GET",
+ re.compile("^/pull/$"),
+ lambda request: handler.on_pull_request(
+ request.args["origin"][0],
+ request.args["v"]
+ )
+ )
+
+ # This is when someone asks for a data item for a given server
+ # data_id pair.
+ self.server.register_path(
+ "GET",
+ re.compile("^/pdu/([^/]*)/([^/]*)/$"),
+ lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
+ pdu_origin, pdu_id
+ )
+ )
+
+ # This is when someone asks for all data for a given context.
+ self.server.register_path(
+ "GET",
+ re.compile("^/state/([^/]*)/$"),
+ lambda request, context: handler.on_context_state_request(
+ context
+ )
+ )
+
+ self.server.register_path(
+ "GET",
+ re.compile("^/paginate/([^/]*)/$"),
+ lambda request, context: self._on_paginate_request(
+ context, request.args["v"],
+ request.args["limit"]
+ )
+ )
+
+ self.server.register_path(
+ "GET",
+ re.compile("^/context/([^/]*)/$"),
+ lambda request, context: handler.on_context_pdus_request(context)
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_send_request(self, request, transaction_id):
+ """ Called on PUT /send/<transaction_id>/
+
+ Args:
+ request (twisted.web.http.Request): The HTTP request.
+ transaction_id (str): The transaction_id associated with this
+ request. This is *not* None.
+
+ Returns:
+ Deferred: Results in a tuple of `(code, response)`, where
+ `response` is a python dict to be converted into JSON that is
+ used as the response body.
+ """
+ # Parse the request
+ try:
+ data = request.content.read()
+
+ l = data[:20].encode("string_escape")
+ logger.debug("Got data: \"%s\"", l)
+
+ transaction_data = json.loads(data)
+
+ logger.debug(
+ "Decoded %s: %s",
+ transaction_id, str(transaction_data)
+ )
+
+ # We should ideally be getting this from the security layer.
+ # origin = body["origin"]
+
+ # Add some extra data to the transaction dict that isn't included
+ # in the request body.
+ transaction_data.update(
+ transaction_id=transaction_id,
+ destination=self.server_name
+ )
+
+ except Exception as e:
+ logger.exception(e)
+ defer.returnValue((400, {"error": "Invalid transaction"}))
+ return
+
+ code, response = yield self.received_handler.on_incoming_transaction(
+ transaction_data
+ )
+
+ defer.returnValue((code, response))
+
+ @defer.inlineCallbacks
+ @log_function
+ def _do_request_for_transaction(self, destination, path, args={}):
+ """
+ Args:
+ destination (str)
+ path (str)
+ args (dict): This is parsed directly to the HttpClient.
+
+ Returns:
+ Deferred: Results in a dict.
+ """
+
+ data = yield self.client.get_json(
+ destination,
+ path=path,
+ args=args,
+ )
+
+ # Add certain keys to the JSON, ready for decoding as a Transaction
+ data.update(
+ origin=destination,
+ destination=self.server_name,
+ transaction_id=None
+ )
+
+ defer.returnValue(data)
+
+ @log_function
+ def _on_paginate_request(self, context, v_list, limits):
+ if not limits:
+ return defer.succeed(
+ (400, {"error": "Did not include limit param"})
+ )
+
+ limit = int(limits[-1])
+
+ versions = [v.split(",", 1) for v in v_list]
+
+ return self.request_handler.on_paginate_request(
+ context, versions, limit)
+
+
+class TransportReceivedHandler(object):
+ """ Callbacks used when we receive a transaction
+ """
+ def on_incoming_transaction(self, transaction):
+ """ Called on PUT /send/<transaction_id>, or on response to a request
+ that we sent (e.g. a pagination request)
+
+ Args:
+ transaction (synapse.transaction.Transaction): The transaction that
+ was sent to us.
+
+ Returns:
+ twisted.internet.defer.Deferred: A deferred that get's fired when
+ the transaction has finished being processed.
+
+ The result should be a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
+
+
+class TransportRequestHandler(object):
+ """ Handlers used when someone want's data from us
+ """
+ def on_pull_request(self, versions):
+ """ Called on GET /pull/?v=...
+
+ This is hit when a remote home server wants to get all data
+ after a given transaction. Mainly used when a home server comes back
+ online and wants to get everything it has missed.
+
+ Args:
+ versions (list): A list of transaction_ids that should be used to
+ determine what PDUs the remote side have not yet seen.
+
+ Returns:
+ Deferred: Resultsin a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
+
+ def on_pdu_request(self, pdu_origin, pdu_id):
+ """ Called on GET /pdu/<pdu_origin>/<pdu_id>/
+
+ Someone wants a particular PDU. This PDU may or may not have originated
+ from us.
+
+ Args:
+ pdu_origin (str)
+ pdu_id (str)
+
+ Returns:
+ Deferred: Resultsin a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
+
+ def on_context_state_request(self, context):
+ """ Called on GET /state/<context>/
+
+ Get's hit when someone wants all the *current* state for a given
+ contexts.
+
+ Args:
+ context (str): The name of the context that we're interested in.
+
+ Returns:
+ twisted.internet.defer.Deferred: A deferred that get's fired when
+ the transaction has finished being processed.
+
+ The result should be a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
+
+ def on_paginate_request(self, context, versions, limit):
+ """ Called on GET /paginate/<context>/?v=...&limit=...
+
+ Get's hit when we want to paginate backwards on a given context from
+ the given point.
+
+ Args:
+ context (str): The context to paginate on
+ versions (list): A list of 2-tuple's representing where to paginate
+ from, in the form `(pdu_id, origin)`
+ limit (int): How many pdus to return.
+
+ Returns:
+ Deferred: Resultsin a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
new file mode 100644
index 0000000000..0efea7b768
--- /dev/null
+++ b/synapse/federation/units.py
@@ -0,0 +1,236 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Defines the JSON structure of the protocol units used by the server to
+server protocol.
+"""
+
+from synapse.util.jsonobject import JsonEncodedObject
+
+import logging
+import json
+import copy
+
+
+logger = logging.getLogger(__name__)
+
+
+class Pdu(JsonEncodedObject):
+ """ A Pdu represents a piece of data sent from a server and is associated
+ with a context.
+
+ A Pdu can be classified as "state". For a given context, we can efficiently
+ retrieve all state pdu's that haven't been clobbered. Clobbering is done
+ via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
+ is a state pdu if `is_state` is True.
+
+ Example pdu::
+
+ {
+ "pdu_id": "78c",
+ "ts": 1404835423000,
+ "origin": "bar",
+ "prev_ids": [
+ ["23b", "foo"],
+ ["56a", "bar"],
+ ],
+ "content": { ... },
+ }
+
+ """
+
+ valid_keys = [
+ "pdu_id",
+ "context",
+ "origin",
+ "ts",
+ "pdu_type",
+ "destinations",
+ "transaction_id",
+ "prev_pdus",
+ "depth",
+ "content",
+ "outlier",
+ "is_state", # Below this are keys valid only for State Pdus.
+ "state_key",
+ "power_level",
+ "prev_state_id",
+ "prev_state_origin",
+ ]
+
+ internal_keys = [
+ "destinations",
+ "transaction_id",
+ "outlier",
+ ]
+
+ required_keys = [
+ "pdu_id",
+ "context",
+ "origin",
+ "ts",
+ "pdu_type",
+ "content",
+ ]
+
+ # TODO: We need to make this properly load content rather than
+ # just leaving it as a dict. (OR DO WE?!)
+
+ def __init__(self, destinations=[], is_state=False, prev_pdus=[],
+ outlier=False, **kwargs):
+ if is_state:
+ for required_key in ["state_key"]:
+ if required_key not in kwargs:
+ raise RuntimeError("Key %s is required" % required_key)
+
+ super(Pdu, self).__init__(
+ destinations=destinations,
+ is_state=is_state,
+ prev_pdus=prev_pdus,
+ outlier=outlier,
+ **kwargs
+ )
+
+ @classmethod
+ def from_pdu_tuple(cls, pdu_tuple):
+ """ Converts a PduTuple to a Pdu
+
+ Args:
+ pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
+ convert
+
+ Returns:
+ Pdu
+ """
+ if pdu_tuple:
+ d = copy.copy(pdu_tuple.pdu_entry._asdict())
+
+ d["content"] = json.loads(d["content_json"])
+ del d["content_json"]
+
+ args = {f: d[f] for f in cls.valid_keys if f in d}
+ if "unrecognized_keys" in d and d["unrecognized_keys"]:
+ args.update(json.loads(d["unrecognized_keys"]))
+
+ return Pdu(
+ prev_pdus=pdu_tuple.prev_pdu_list,
+ **args
+ )
+ else:
+ return None
+
+ def __str__(self):
+ return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
+
+ def __repr__(self):
+ return "<%s, %s>" % (self.__class__.__name__, repr(self.__dict__))
+
+
+class Edu(JsonEncodedObject):
+ """ An Edu represents a piece of data sent from one homeserver to another.
+
+ In comparison to Pdus, Edus are not persisted for a long time on disk, are
+ not meaningful beyond a given pair of homeservers, and don't have an
+ internal ID or previous references graph.
+ """
+
+ valid_keys = [
+ "origin",
+ "destination",
+ "edu_type",
+ "content",
+ ]
+
+ required_keys = [
+ "origin",
+ "destination",
+ "edu_type",
+ ]
+
+
+class Transaction(JsonEncodedObject):
+ """ A transaction is a list of Pdus and Edus to be sent to a remote home
+ server with some extra metadata.
+
+ Example transaction::
+
+ {
+ "origin": "foo",
+ "prev_ids": ["abc", "def"],
+ "pdus": [
+ ...
+ ],
+ }
+
+ """
+
+ valid_keys = [
+ "transaction_id",
+ "origin",
+ "destination",
+ "ts",
+ "previous_ids",
+ "pdus",
+ "edus",
+ ]
+
+ internal_keys = [
+ "transaction_id",
+ "destination",
+ ]
+
+ required_keys = [
+ "transaction_id",
+ "origin",
+ "destination",
+ "ts",
+ "pdus",
+ ]
+
+ def __init__(self, transaction_id=None, pdus=[], **kwargs):
+ """ If we include a list of pdus then we decode then as PDU's
+ automatically.
+ """
+
+ # If there's no EDUs then remove the arg
+ if "edus" in kwargs and not kwargs["edus"]:
+ del kwargs["edus"]
+
+ super(Transaction, self).__init__(
+ transaction_id=transaction_id,
+ pdus=pdus,
+ **kwargs
+ )
+
+ @staticmethod
+ def create_new(pdus, **kwargs):
+ """ Used to create a new transaction. Will auto fill out
+ transaction_id and ts keys.
+ """
+ if "ts" not in kwargs:
+ raise KeyError("Require 'ts' to construct a Transaction")
+ if "transaction_id" not in kwargs:
+ raise KeyError(
+ "Require 'transaction_id' to construct a Transaction"
+ )
+
+ for p in pdus:
+ p.transaction_id = kwargs["transaction_id"]
+
+ kwargs["pdus"] = [p.get_dict() for p in pdus]
+
+ return Transaction(**kwargs)
+
+
+
|