summary refs log tree commit diff
path: root/synapse/federation/replication.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/replication.py')
-rw-r--r--synapse/federation/replication.py582
1 files changed, 582 insertions, 0 deletions
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
new file mode 100644
index 0000000000..0f5b974291
--- /dev/null
+++ b/synapse/federation/replication.py
@@ -0,0 +1,582 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This layer is responsible for replicating with remote home servers using
+a given transport.
+"""
+
+from twisted.internet import defer
+
+from .units import Transaction, Pdu, Edu
+
+from .persistence import PduActions, TransactionActions
+
+from synapse.util.logutils import log_function
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationLayer(object):
+    """This layer is responsible for replicating with remote home servers over
+    the given transport. I.e., does the sending and receiving of PDUs to
+    remote home servers.
+
+    The layer communicates with the rest of the server via a registered
+    ReplicationHandler.
+
+    In more detail, the layer:
+        * Receives incoming data and processes it into transactions and pdus.
+        * Fetches any PDUs it thinks it might have missed.
+        * Keeps the current state for contexts up to date by applying the
+          suitable conflict resolution.
+        * Sends outgoing pdus wrapped in transactions.
+        * Fills out the references to previous pdus/transactions appropriately
+          for outgoing data.
+    """
+
+    def __init__(self, hs, transport_layer):
+        self.server_name = hs.hostname
+
+        self.transport_layer = transport_layer
+        self.transport_layer.register_received_handler(self)
+        self.transport_layer.register_request_handler(self)
+
+        self.store = hs.get_datastore()
+        self.pdu_actions = PduActions(self.store)
+        self.transaction_actions = TransactionActions(self.store)
+
+        self._transaction_queue = _TransactionQueue(
+            hs, self.transaction_actions, transport_layer
+        )
+
+        self.handler = None
+        self.edu_handlers = {}
+
+        self._order = 0
+
+        self._clock = hs.get_clock()
+
+    def set_handler(self, handler):
+        """Sets the handler that the replication layer will use to communicate
+        receipt of new PDUs from other home servers. The required methods are
+        documented on :py:class:`.ReplicationHandler`.
+        """
+        self.handler = handler
+
+    def register_edu_handler(self, edu_type, handler):
+        if edu_type in self.edu_handlers:
+            raise KeyError("Already have an EDU handler for %s" % (edu_type))
+
+        self.edu_handlers[edu_type] = handler
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_pdu(self, pdu):
+        """Informs the replication layer about a new PDU generated within the
+        home server that should be transmitted to others.
+
+        This will fill out various attributes on the PDU object, e.g. the
+        `prev_pdus` key.
+
+        *Note:* The home server should always call `send_pdu` even if it knows
+        that it does not need to be replicated to other home servers. This is
+        in case e.g. someone else joins via a remote home server and then
+        paginates.
+
+        TODO: Figure out when we should actually resolve the deferred.
+
+        Args:
+            pdu (Pdu): The new Pdu.
+
+        Returns:
+            Deferred: Completes when we have successfully processed the PDU
+            and replicated it to any interested remote home servers.
+        """
+        order = self._order
+        self._order += 1
+
+        logger.debug("[%s] Persisting PDU", pdu.pdu_id)
+
+        #yield self.pdu_actions.populate_previous_pdus(pdu)
+
+        # Save *before* trying to send
+        yield self.pdu_actions.persist_outgoing(pdu)
+
+        logger.debug("[%s] Persisted PDU", pdu.pdu_id)
+        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+
+        # TODO, add errback, etc.
+        self._transaction_queue.enqueue_pdu(pdu, order)
+
+        logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+
+    @log_function
+    def send_edu(self, destination, edu_type, content):
+        edu = Edu(
+            origin=self.server_name,
+            destination=destination,
+            edu_type=edu_type,
+            content=content,
+        )
+
+        # TODO, add errback, etc.
+        self._transaction_queue.enqueue_edu(edu)
+
+    @defer.inlineCallbacks
+    @log_function
+    def paginate(self, dest, context, limit):
+        """Requests some more historic PDUs for the given context from the
+        given destination server.
+
+        Args:
+            dest (str): The remote home server to ask.
+            context (str): The context to paginate back on.
+            limit (int): The maximum number of PDUs to return.
+
+        Returns:
+            Deferred: Results in the received PDUs.
+        """
+        extremities = yield self.store.get_oldest_pdus_in_context(context)
+
+        logger.debug("paginate extrem=%s", extremities)
+
+        # If there are no extremeties then we've (probably) reached the start.
+        if not extremities:
+            return
+
+        transaction_data = yield self.transport_layer.paginate(
+            dest, context, extremities, limit)
+
+        logger.debug("paginate transaction_data=%s", repr(transaction_data))
+
+        transaction = Transaction(**transaction_data)
+
+        pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
+        for pdu in pdus:
+            yield self._handle_new_pdu(pdu)
+
+        defer.returnValue(pdus)
+
+    @defer.inlineCallbacks
+    @log_function
+    def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+        """Requests the PDU with given origin and ID from the remote home
+        server.
+
+        This will persist the PDU locally upon receipt.
+
+        Args:
+            destination (str): Which home server to query
+            pdu_origin (str): The home server that originally sent the pdu.
+            pdu_id (str)
+            outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
+                it's from an arbitary point in the context as opposed to part
+                of the current block of PDUs. Defaults to `False`
+
+        Returns:
+            Deferred: Results in the requested PDU.
+        """
+
+        transaction_data = yield self.transport_layer.get_pdu(
+            destination, pdu_origin, pdu_id)
+
+        transaction = Transaction(**transaction_data)
+
+        pdu_list = [Pdu(outlier=outlier, **p) for p in transaction.pdus]
+
+        pdu = None
+        if pdu_list:
+            pdu = pdu_list[0]
+            yield self._handle_new_pdu(pdu)
+
+        defer.returnValue(pdu)
+
+    @defer.inlineCallbacks
+    @log_function
+    def get_state_for_context(self, destination, context):
+        """Requests all of the `current` state PDUs for a given context from
+        a remote home server.
+
+        Args:
+            destination (str): The remote homeserver to query for the state.
+            context (str): The context we're interested in.
+
+        Returns:
+            Deferred: Results in a list of PDUs.
+        """
+
+        transaction_data = yield self.transport_layer.get_context_state(
+            destination, context)
+
+        transaction = Transaction(**transaction_data)
+
+        pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
+        for pdu in pdus:
+            yield self._handle_new_pdu(pdu)
+
+        defer.returnValue(pdus)
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_context_pdus_request(self, context):
+        pdus = yield self.pdu_actions.get_all_pdus_from_context(
+            context
+        )
+        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_paginate_request(self, context, versions, limit):
+
+        pdus = yield self.pdu_actions.paginate(context, versions, limit)
+
+        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_incoming_transaction(self, transaction_data):
+        transaction = Transaction(**transaction_data)
+
+        logger.debug("[%s] Got transaction", transaction.transaction_id)
+
+        response = yield self.transaction_actions.have_responded(transaction)
+
+        if response:
+            logger.debug("[%s] We've already responed to this request",
+                         transaction.transaction_id)
+            defer.returnValue(response)
+            return
+
+        logger.debug("[%s] Transacition is new", transaction.transaction_id)
+
+        pdu_list = [Pdu(**p) for p in transaction.pdus]
+
+        dl = []
+        for pdu in pdu_list:
+            dl.append(self._handle_new_pdu(pdu))
+
+        if hasattr(transaction, "edus"):
+            for edu in [Edu(**x) for x in transaction.edus]:
+                self.received_edu(edu.origin, edu.edu_type, edu.content)
+
+        results = yield defer.DeferredList(dl)
+
+        ret = []
+        for r in results:
+            if r[0]:
+                ret.append({})
+            else:
+                logger.exception(r[1])
+                ret.append({"error": str(r[1])})
+
+        logger.debug("Returning: %s", str(ret))
+
+        yield self.transaction_actions.set_response(
+            transaction,
+            200, response
+        )
+        defer.returnValue((200, response))
+
+    def received_edu(self, origin, edu_type, content):
+        if edu_type in self.edu_handlers:
+            self.edu_handlers[edu_type](origin, content)
+        else:
+            logger.warn("Received EDU of type %s with no handler", edu_type)
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_context_state_request(self, context):
+        results = yield self.store.get_current_state_for_context(
+            context
+        )
+
+        logger.debug("Context returning %d results", len(results))
+
+        pdus = [Pdu.from_pdu_tuple(p) for p in results]
+        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_pdu_request(self, pdu_origin, pdu_id):
+        pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+
+        if pdu:
+            defer.returnValue(
+                (200, self._transaction_from_pdus([pdu]).get_dict())
+            )
+        else:
+            defer.returnValue((404, ""))
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_pull_request(self, origin, versions):
+        transaction_id = max([int(v) for v in versions])
+
+        response = yield self.pdu_actions.after_transaction(
+            transaction_id,
+            origin,
+            self.server_name
+        )
+
+        if not response:
+            response = []
+
+        defer.returnValue(
+            (200, self._transaction_from_pdus(response).get_dict())
+        )
+
+    @defer.inlineCallbacks
+    @log_function
+    def _get_persisted_pdu(self, pdu_id, pdu_origin):
+        """ Get a PDU from the database with given origin and id.
+
+        Returns:
+            Deferred: Results in a `Pdu`.
+        """
+        pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
+
+        defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+
+    def _transaction_from_pdus(self, pdu_list):
+        """Returns a new Transaction containing the given PDUs suitable for
+        transmission.
+        """
+        return Transaction(
+            pdus=[p.get_dict() for p in pdu_list],
+            origin=self.server_name,
+            ts=int(self._clock.time_msec()),
+            destination=None,
+        )
+
+    @defer.inlineCallbacks
+    @log_function
+    def _handle_new_pdu(self, pdu):
+        # We reprocess pdus when we have seen them only as outliers
+        existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+
+        if existing and (not existing.outlier or pdu.outlier):
+            logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+            defer.returnValue({})
+            return
+
+        # Get missing pdus if necessary.
+        is_new = yield self.pdu_actions.is_new(pdu)
+        if is_new and not pdu.outlier:
+            # We only paginate backwards to the min depth.
+            min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+
+            if min_depth and pdu.depth > min_depth:
+                for pdu_id, origin in pdu.prev_pdus:
+                    exists = yield self._get_persisted_pdu(pdu_id, origin)
+
+                    if not exists:
+                        logger.debug("Requesting pdu %s %s", pdu_id, origin)
+
+                        try:
+                            yield self.get_pdu(
+                                pdu.origin,
+                                pdu_id=pdu_id,
+                                pdu_origin=origin
+                            )
+                            logger.debug("Processed pdu %s %s", pdu_id, origin)
+                        except:
+                            # TODO(erikj): Do some more intelligent retries.
+                            logger.exception("Failed to get PDU")
+
+        # Persist the Pdu, but don't mark it as processed yet.
+        yield self.pdu_actions.persist_received(pdu)
+
+        ret = yield self.handler.on_receive_pdu(pdu)
+
+        yield self.pdu_actions.mark_as_processed(pdu)
+
+        defer.returnValue(ret)
+
+    def __str__(self):
+        return "<ReplicationLayer(%s)>" % self.server_name
+
+
+class ReplicationHandler(object):
+    """This defines the methods that the :py:class:`.ReplicationLayer` will
+    use to communicate with the rest of the home server.
+    """
+    def on_receive_pdu(self, pdu):
+        raise NotImplementedError("on_receive_pdu")
+
+
+class _TransactionQueue(object):
+    """This class makes sure we only have one transaction in flight at
+    a time for a given destination.
+
+    It batches pending PDUs into single transactions.
+    """
+
+    def __init__(self, hs, transaction_actions, transport_layer):
+
+        self.server_name = hs.hostname
+        self.transaction_actions = transaction_actions
+        self.transport_layer = transport_layer
+
+        self._clock = hs.get_clock()
+
+        # Is a mapping from destinations -> deferreds. Used to keep track
+        # of which destinations have transactions in flight and when they are
+        # done
+        self.pending_transactions = {}
+
+        # Is a mapping from destination -> list of
+        # tuple(pending pdus, deferred, order)
+        self.pending_pdus_by_dest = {}
+        # destination -> list of tuple(edu, deferred)
+        self.pending_edus_by_dest = {}
+
+        # HACK to get unique tx id
+        self._next_txn_id = int(self._clock.time_msec())
+
+    @defer.inlineCallbacks
+    @log_function
+    def enqueue_pdu(self, pdu, order):
+        # We loop through all destinations to see whether we already have
+        # a transaction in progress. If we do, stick it in the pending_pdus
+        # table and we'll get back to it later.
+
+        destinations = [
+            d for d in pdu.destinations
+            if d != self.server_name
+        ]
+
+        logger.debug("Sending to: %s", str(destinations))
+
+        if not destinations:
+            return
+
+        deferreds = []
+
+        for destination in destinations:
+            deferred = defer.Deferred()
+            self.pending_pdus_by_dest.setdefault(destination, []).append(
+                (pdu, deferred, order)
+            )
+
+            self._attempt_new_transaction(destination)
+
+            deferreds.append(deferred)
+
+        yield defer.DeferredList(deferreds)
+
+    # NO inlineCallbacks
+    def enqueue_edu(self, edu):
+        destination = edu.destination
+
+        deferred = defer.Deferred()
+        self.pending_edus_by_dest.setdefault(destination, []).append(
+            (edu, deferred)
+        )
+
+        def eb(failure):
+            deferred.errback(failure)
+        self._attempt_new_transaction(destination).addErrback(eb)
+
+        return deferred
+
+    @defer.inlineCallbacks
+    @log_function
+    def _attempt_new_transaction(self, destination):
+        if destination in self.pending_transactions:
+            return
+
+        #  list of (pending_pdu, deferred, order)
+        pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+        pending_edus = self.pending_edus_by_dest.pop(destination, [])
+
+        if not pending_pdus and not pending_edus:
+            return
+
+        logger.debug("TX [%s] Attempting new transaction", destination)
+
+        # Sort based on the order field
+        pending_pdus.sort(key=lambda t: t[2])
+
+        pdus = [x[0] for x in pending_pdus]
+        edus = [x[0] for x in pending_edus]
+        deferreds = [x[1] for x in pending_pdus + pending_edus]
+
+        try:
+            self.pending_transactions[destination] = 1
+
+            logger.debug("TX [%s] Persisting transaction...", destination)
+
+            transaction = Transaction.create_new(
+                ts=self._clock.time_msec(),
+                transaction_id=self._next_txn_id,
+                origin=self.server_name,
+                destination=destination,
+                pdus=pdus,
+                edus=edus,
+            )
+
+            self._next_txn_id += 1
+
+            yield self.transaction_actions.prepare_to_send(transaction)
+
+            logger.debug("TX [%s] Persisted transaction", destination)
+            logger.debug("TX [%s] Sending transaction...", destination)
+
+            # Actually send the transaction
+            code, response = yield self.transport_layer.send_transaction(
+                transaction
+            )
+
+            logger.debug("TX [%s] Sent transaction", destination)
+            logger.debug("TX [%s] Marking as delivered...", destination)
+
+            yield self.transaction_actions.delivered(
+                transaction, code, response
+            )
+
+            logger.debug("TX [%s] Marked as delivered", destination)
+            logger.debug("TX [%s] Yielding to callbacks...", destination)
+
+            for deferred in deferreds:
+                if code == 200:
+                    deferred.callback(None)
+                else:
+                    deferred.errback(RuntimeError("Got status %d" % code))
+
+                # Ensures we don't continue until all callbacks on that
+                # deferred have fired
+                yield deferred
+
+            logger.debug("TX [%s] Yielded to callbacks", destination)
+
+        except Exception as e:
+            logger.error("TX Problem in _attempt_transaction")
+
+            # We capture this here as there as nothing actually listens
+            # for this finishing functions deferred.
+            logger.exception(e)
+
+            for deferred in deferreds:
+                deferred.errback(e)
+                yield deferred
+
+        finally:
+            # We want to be *very* sure we delete this after we stop processing
+            self.pending_transactions.pop(destination, None)
+
+            # Check to see if there is anything else to send.
+            self._attempt_new_transaction(destination)