summary refs log tree commit diff
path: root/synapse/federation/replication.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/replication.py')
-rw-r--r--synapse/federation/replication.py286
1 files changed, 186 insertions, 100 deletions
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 092411eaf9..92a9678e2c 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
 
 from .units import Transaction, Pdu, Edu
 
-from .persistence import PduActions, TransactionActions
+from .persistence import TransactionActions
 
 from synapse.util.logutils import log_function
 
@@ -57,7 +57,7 @@ class ReplicationLayer(object):
         self.transport_layer.register_request_handler(self)
 
         self.store = hs.get_datastore()
-        self.pdu_actions = PduActions(self.store)
+        # self.pdu_actions = PduActions(self.store)
         self.transaction_actions = TransactionActions(self.store)
 
         self._transaction_queue = _TransactionQueue(
@@ -81,7 +81,7 @@ class ReplicationLayer(object):
 
     def register_edu_handler(self, edu_type, handler):
         if edu_type in self.edu_handlers:
-            raise KeyError("Already have an EDU handler for %s" % (edu_type))
+            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
 
         self.edu_handlers[edu_type] = handler
 
@@ -102,24 +102,17 @@ class ReplicationLayer(object):
           object to encode as JSON.
         """
         if query_type in self.query_handlers:
-            raise KeyError("Already have a Query handler for %s" % (query_type))
+            raise KeyError(
+                "Already have a Query handler for %s" % (query_type,)
+            )
 
         self.query_handlers[query_type] = handler
 
-    @defer.inlineCallbacks
     @log_function
     def send_pdu(self, pdu):
         """Informs the replication layer about a new PDU generated within the
         home server that should be transmitted to others.
 
-        This will fill out various attributes on the PDU object, e.g. the
-        `prev_pdus` key.
-
-        *Note:* The home server should always call `send_pdu` even if it knows
-        that it does not need to be replicated to other home servers. This is
-        in case e.g. someone else joins via a remote home server and then
-        backfills.
-
         TODO: Figure out when we should actually resolve the deferred.
 
         Args:
@@ -132,18 +125,15 @@ class ReplicationLayer(object):
         order = self._order
         self._order += 1
 
-        logger.debug("[%s] Persisting PDU", pdu.pdu_id)
-
-        # Save *before* trying to send
-        yield self.store.persist_event(pdu=pdu)
-
-        logger.debug("[%s] Persisted PDU", pdu.pdu_id)
-        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
 
         # TODO, add errback, etc.
         self._transaction_queue.enqueue_pdu(pdu, order)
 
-        logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+        logger.debug(
+            "[%s] transaction_layer.enqueue_pdu... done",
+            pdu.event_id
+        )
 
     @log_function
     def send_edu(self, destination, edu_type, content):
@@ -159,6 +149,11 @@ class ReplicationLayer(object):
         return defer.succeed(None)
 
     @log_function
+    def send_failure(self, failure, destination):
+        self._transaction_queue.enqueue_failure(failure, destination)
+        return defer.succeed(None)
+
+    @log_function
     def make_query(self, destination, query_type, args,
                    retry_on_dns_fail=True):
         """Sends a federation Query to a remote homeserver of the given type
@@ -181,7 +176,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def backfill(self, dest, context, limit):
+    def backfill(self, dest, context, limit, extremities):
         """Requests some more historic PDUs for the given context from the
         given destination server.
 
@@ -189,12 +184,12 @@ class ReplicationLayer(object):
             dest (str): The remote home server to ask.
             context (str): The context to backfill.
             limit (int): The maximum number of PDUs to return.
+            extremities (list): List of PDU id and origins of the first pdus
+                we have seen from the context
 
         Returns:
             Deferred: Results in the received PDUs.
         """
-        extremities = yield self.store.get_oldest_pdus_in_context(context)
-
         logger.debug("backfill extrem=%s", extremities)
 
         # If there are no extremeties then we've (probably) reached the start.
@@ -216,7 +211,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+    def get_pdu(self, destination, event_id, outlier=False):
         """Requests the PDU with given origin and ID from the remote home
         server.
 
@@ -225,7 +220,7 @@ class ReplicationLayer(object):
         Args:
             destination (str): Which home server to query
             pdu_origin (str): The home server that originally sent the pdu.
-            pdu_id (str)
+            event_id (str)
             outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                 it's from an arbitary point in the context as opposed to part
                 of the current block of PDUs. Defaults to `False`
@@ -234,8 +229,9 @@ class ReplicationLayer(object):
             Deferred: Results in the requested PDU.
         """
 
-        transaction_data = yield self.transport_layer.get_pdu(
-            destination, pdu_origin, pdu_id)
+        transaction_data = yield self.transport_layer.get_event(
+            destination, event_id
+        )
 
         transaction = Transaction(**transaction_data)
 
@@ -244,13 +240,13 @@ class ReplicationLayer(object):
         pdu = None
         if pdu_list:
             pdu = pdu_list[0]
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdu)
 
     @defer.inlineCallbacks
     @log_function
-    def get_state_for_context(self, destination, context):
+    def get_state_for_context(self, destination, context, event_id=None):
         """Requests all of the `current` state PDUs for a given context from
         a remote home server.
 
@@ -263,29 +259,32 @@ class ReplicationLayer(object):
         """
 
         transaction_data = yield self.transport_layer.get_context_state(
-            destination, context)
+            destination,
+            context,
+            event_id=event_id,
+        )
 
         transaction = Transaction(**transaction_data)
 
         pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
         for pdu in pdus:
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdus)
 
     @defer.inlineCallbacks
     @log_function
     def on_context_pdus_request(self, context):
-        pdus = yield self.pdu_actions.get_all_pdus_from_context(
-            context
+        raise NotImplementedError(
+            "on_context_pdus_request is a security violation"
         )
-        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, context, versions, limit):
-
-        pdus = yield self.pdu_actions.backfill(context, versions, limit)
+        pdus = yield self.handler.on_backfill_request(
+            context, versions, limit
+        )
 
         defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
@@ -295,6 +294,10 @@ class ReplicationLayer(object):
         transaction = Transaction(**transaction_data)
 
         for p in transaction.pdus:
+            if "unsigned" in p:
+                unsigned = p["unsigned"]
+                if "age" in unsigned:
+                    p["age"] = unsigned["age"]
             if "age" in p:
                 p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
                 del p["age"]
@@ -315,11 +318,15 @@ class ReplicationLayer(object):
 
         dl = []
         for pdu in pdu_list:
-            dl.append(self._handle_new_pdu(pdu))
+            dl.append(self._handle_new_pdu(transaction.origin, pdu))
 
         if hasattr(transaction, "edus"):
             for edu in [Edu(**x) for x in transaction.edus]:
-                self.received_edu(transaction.origin, edu.edu_type, edu.content)
+                self.received_edu(
+                    transaction.origin,
+                    edu.edu_type,
+                    edu.content
+                )
 
         results = yield defer.DeferredList(dl)
 
@@ -347,20 +354,26 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def on_context_state_request(self, context):
-        results = yield self.store.get_current_state_for_context(
-            context
-        )
-
-        logger.debug("Context returning %d results", len(results))
+    def on_context_state_request(self, context, event_id):
+        if event_id:
+            pdus = yield self.handler.get_state_for_pdu(
+                event_id
+            )
+        else:
+            raise NotImplementedError("Specify an event")
+        #     results = yield self.store.get_current_state_for_context(
+        #         context
+        #     )
+        #     pdus = [Pdu.from_pdu_tuple(p) for p in results]
+        #
+        # logger.debug("Context returning %d results", len(pdus))
 
-        pdus = [Pdu.from_pdu_tuple(p) for p in results]
         defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
     @defer.inlineCallbacks
     @log_function
-    def on_pdu_request(self, pdu_origin, pdu_id):
-        pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+    def on_pdu_request(self, event_id):
+        pdu = yield self._get_persisted_pdu(event_id)
 
         if pdu:
             defer.returnValue(
@@ -372,20 +385,22 @@ class ReplicationLayer(object):
     @defer.inlineCallbacks
     @log_function
     def on_pull_request(self, origin, versions):
-        transaction_id = max([int(v) for v in versions])
-
-        response = yield self.pdu_actions.after_transaction(
-            transaction_id,
-            origin,
-            self.server_name
-        )
-
-        if not response:
-            response = []
-
-        defer.returnValue(
-            (200, self._transaction_from_pdus(response).get_dict())
-        )
+        raise NotImplementedError("Pull transacions not implemented")
+
+        # transaction_id = max([int(v) for v in versions])
+        #
+        # response = yield self.pdu_actions.after_transaction(
+        #     transaction_id,
+        #     origin,
+        #     self.server_name
+        # )
+        #
+        # if not response:
+        #     response = []
+        #
+        # defer.returnValue(
+        #     (200, self._transaction_from_pdus(response).get_dict())
+        # )
 
     @defer.inlineCallbacks
     def on_query_request(self, query_type, args):
@@ -393,82 +408,138 @@ class ReplicationLayer(object):
             response = yield self.query_handlers[query_type](args)
             defer.returnValue((200, response))
         else:
-            defer.returnValue((404, "No handler for Query type '%s'"
-                % (query_type)
-            ))
+            defer.returnValue(
+                (404, "No handler for Query type '%s'" % (query_type, ))
+            )
+
+    @defer.inlineCallbacks
+    def on_make_join_request(self, context, user_id):
+        pdu = yield self.handler.on_make_join_request(context, user_id)
+        defer.returnValue(pdu.get_dict())
+
+    @defer.inlineCallbacks
+    def on_invite_request(self, origin, content):
+        pdu = Pdu(**content)
+        ret_pdu = yield self.handler.on_send_join_request(origin, pdu)
+        defer.returnValue((200, ret_pdu.get_dict()))
 
     @defer.inlineCallbacks
+    def on_send_join_request(self, origin, content):
+        pdu = Pdu(**content)
+        state = yield self.handler.on_send_join_request(origin, pdu)
+        defer.returnValue((200, self._transaction_from_pdus(state).get_dict()))
+
+    @defer.inlineCallbacks
+    def make_join(self, destination, context, user_id):
+        pdu_dict = yield self.transport_layer.make_join(
+            destination=destination,
+            context=context,
+            user_id=user_id,
+        )
+
+        logger.debug("Got response to make_join: %s", pdu_dict)
+
+        defer.returnValue(Pdu(**pdu_dict))
+
+    @defer.inlineCallbacks
+    def send_join(self, destination, pdu):
+        _, content = yield self.transport_layer.send_join(
+            destination,
+            pdu.room_id,
+            pdu.event_id,
+            pdu.get_dict(),
+        )
+
+        logger.debug("Got content: %s", content)
+        pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])]
+        for pdu in pdus:
+            yield self._handle_new_pdu(destination, pdu)
+
+        defer.returnValue(pdus)
+
     @log_function
-    def _get_persisted_pdu(self, pdu_id, pdu_origin):
+    def _get_persisted_pdu(self, event_id):
         """ Get a PDU from the database with given origin and id.
 
         Returns:
             Deferred: Results in a `Pdu`.
         """
-        pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
-
-        defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+        return self.handler.get_persisted_pdu(event_id)
 
     def _transaction_from_pdus(self, pdu_list):
         """Returns a new Transaction containing the given PDUs suitable for
         transmission.
         """
         pdus = [p.get_dict() for p in pdu_list]
+        time_now = self._clock.time_msec()
         for p in pdus:
-            if "age_ts" in pdus:
-                p["age"] = int(self.clock.time_msec()) - p["age_ts"]
-
+            if "age_ts" in p:
+                age = time_now - p["age_ts"]
+                p.setdefault("unsigned", {})["age"] = int(age)
+                del p["age_ts"]
         return Transaction(
             origin=self.server_name,
             pdus=pdus,
-            origin_server_ts=int(self._clock.time_msec()),
+            origin_server_ts=int(time_now),
             destination=None,
         )
 
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_pdu(self, pdu, backfilled=False):
+    def _handle_new_pdu(self, origin, pdu, backfilled=False):
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+        existing = yield self._get_persisted_pdu(pdu.event_id)
 
         if existing and (not existing.outlier or pdu.outlier):
-            logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+            logger.debug("Already seen pdu %s", pdu.event_id)
             defer.returnValue({})
             return
 
+        state = None
+
         # Get missing pdus if necessary.
-        is_new = yield self.pdu_actions.is_new(pdu)
-        if is_new and not pdu.outlier:
+        if not pdu.outlier:
             # We only backfill backwards to the min depth.
-            min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+            min_depth = yield self.handler.get_min_depth_for_context(
+                pdu.room_id
+            )
 
             if min_depth and pdu.depth > min_depth:
-                for pdu_id, origin in pdu.prev_pdus:
-                    exists = yield self._get_persisted_pdu(pdu_id, origin)
+                for event_id, hashes in pdu.prev_events:
+                    exists = yield self._get_persisted_pdu(event_id)
 
                     if not exists:
-                        logger.debug("Requesting pdu %s %s", pdu_id, origin)
+                        logger.debug("Requesting pdu %s", event_id)
 
                         try:
                             yield self.get_pdu(
                                 pdu.origin,
-                                pdu_id=pdu_id,
-                                pdu_origin=origin
+                                event_id=event_id,
                             )
-                            logger.debug("Processed pdu %s %s", pdu_id, origin)
+                            logger.debug("Processed pdu %s", event_id)
                         except:
                             # TODO(erikj): Do some more intelligent retries.
                             logger.exception("Failed to get PDU")
+            else:
+                # We need to get the state at this event, since we have reached
+                # a backward extremity edge.
+                state = yield self.get_state_for_context(
+                    origin, pdu.room_id, pdu.event_id,
+                )
 
         # Persist the Pdu, but don't mark it as processed yet.
-        yield self.store.persist_event(pdu=pdu)
+        # yield self.store.persist_event(pdu=pdu)
 
         if not backfilled:
-            ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+            ret = yield self.handler.on_receive_pdu(
+                pdu,
+                backfilled=backfilled,
+                state=state,
+            )
         else:
             ret = None
 
-        yield self.pdu_actions.mark_as_processed(pdu)
+        # yield self.pdu_actions.mark_as_processed(pdu)
 
         defer.returnValue(ret)
 
@@ -476,14 +547,6 @@ class ReplicationLayer(object):
         return "<ReplicationLayer(%s)>" % self.server_name
 
 
-class ReplicationHandler(object):
-    """This defines the methods that the :py:class:`.ReplicationLayer` will
-    use to communicate with the rest of the home server.
-    """
-    def on_receive_pdu(self, pdu):
-        raise NotImplementedError("on_receive_pdu")
-
-
 class _TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
     a time for a given destination.
@@ -509,6 +572,9 @@ class _TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = {}
 
+        # destination -> list of tuple(failure, deferred)
+        self.pending_failures_by_dest = {}
+
         # HACK to get unique tx id
         self._next_txn_id = int(self._clock.time_msec())
 
@@ -562,6 +628,18 @@ class _TransactionQueue(object):
         return deferred
 
     @defer.inlineCallbacks
+    def enqueue_failure(self, failure, destination):
+        deferred = defer.Deferred()
+
+        self.pending_failures_by_dest.setdefault(
+            destination, []
+        ).append(
+            (failure, deferred)
+        )
+
+        yield deferred
+
+    @defer.inlineCallbacks
     @log_function
     def _attempt_new_transaction(self, destination):
         if destination in self.pending_transactions:
@@ -570,8 +648,9 @@ class _TransactionQueue(object):
         #  list of (pending_pdu, deferred, order)
         pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
         pending_edus = self.pending_edus_by_dest.pop(destination, [])
+        pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-        if not pending_pdus and not pending_edus:
+        if not pending_pdus and not pending_edus and not pending_failures:
             return
 
         logger.debug("TX [%s] Attempting new transaction", destination)
@@ -581,7 +660,11 @@ class _TransactionQueue(object):
 
         pdus = [x[0] for x in pending_pdus]
         edus = [x[0] for x in pending_edus]
-        deferreds = [x[1] for x in pending_pdus + pending_edus]
+        failures = [x[0].get_dict() for x in pending_failures]
+        deferreds = [
+            x[1]
+            for x in pending_pdus + pending_edus + pending_failures
+        ]
 
         try:
             self.pending_transactions[destination] = 1
@@ -589,12 +672,13 @@ class _TransactionQueue(object):
             logger.debug("TX [%s] Persisting transaction...", destination)
 
             transaction = Transaction.create_new(
-                origin_server_ts=self._clock.time_msec(),
+                origin_server_ts=int(self._clock.time_msec()),
                 transaction_id=str(self._next_txn_id),
                 origin=self.server_name,
                 destination=destination,
                 pdus=pdus,
                 edus=edus,
+                pdu_failures=failures,
             )
 
             self._next_txn_id += 1
@@ -614,7 +698,9 @@ class _TransactionQueue(object):
                 if "pdus" in data:
                     for p in data["pdus"]:
                         if "age_ts" in p:
-                            p["age"] = now - int(p["age_ts"])
+                            unsigned = p.setdefault("unsigned", {})
+                            unsigned["age"] = now - int(p["age_ts"])
+                            del p["age_ts"]
                 return data
 
             code, response = yield self.transport_layer.send_transaction(