summary refs log tree commit diff
path: root/synapse/federation/replication.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/replication.py')
-rw-r--r--synapse/federation/replication.py117
1 files changed, 91 insertions, 26 deletions
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 65a53ae17c..6bfb30b42d 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -24,6 +24,7 @@ from .units import Transaction, Edu
 from .persistence import TransactionActions
 
 from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext
 
 import logging
 
@@ -319,19 +320,20 @@ class ReplicationLayer(object):
 
         logger.debug("[%s] Transacition is new", transaction.transaction_id)
 
-        dl = []
-        for pdu in pdu_list:
-            dl.append(self._handle_new_pdu(transaction.origin, pdu))
+        with PreserveLoggingContext():
+            dl = []
+            for pdu in pdu_list:
+                dl.append(self._handle_new_pdu(transaction.origin, pdu))
 
-        if hasattr(transaction, "edus"):
-            for edu in [Edu(**x) for x in transaction.edus]:
-                self.received_edu(
-                    transaction.origin,
-                    edu.edu_type,
-                    edu.content
-                )
+            if hasattr(transaction, "edus"):
+                for edu in [Edu(**x) for x in transaction.edus]:
+                    self.received_edu(
+                        transaction.origin,
+                        edu.edu_type,
+                        edu.content
+                    )
 
-        results = yield defer.DeferredList(dl)
+            results = yield defer.DeferredList(dl)
 
         ret = []
         for r in results:
@@ -425,7 +427,9 @@ class ReplicationLayer(object):
         time_now = self._clock.time_msec()
         defer.returnValue((200, {
             "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-            "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+            "auth_chain": [
+                p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
+            ],
         }))
 
     @defer.inlineCallbacks
@@ -436,7 +440,9 @@ class ReplicationLayer(object):
             (
                 200,
                 {
-                    "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+                    "auth_chain": [
+                        a.get_pdu_json(time_now) for a in auth_pdus
+                    ],
                 }
             )
         )
@@ -457,7 +463,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     def send_join(self, destination, pdu):
-        time_now  = self._clock.time_msec()
+        time_now = self._clock.time_msec()
         _, content = yield self.transport_layer.send_join(
             destination,
             pdu.room_id,
@@ -475,11 +481,17 @@ class ReplicationLayer(object):
         # FIXME: We probably want to do something with the auth_chain given
         # to us
 
-        # auth_chain = [
-        #    Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
-        # ]
+        auth_chain = [
+            self.event_from_pdu_json(p, outlier=True)
+            for p in content.get("auth_chain", [])
+        ]
 
-        defer.returnValue(state)
+        auth_chain.sort(key=lambda e: e.depth)
+
+        defer.returnValue({
+            "state": state,
+            "auth_chain": auth_chain,
+        })
 
     @defer.inlineCallbacks
     def send_invite(self, destination, context, event_id, pdu):
@@ -498,13 +510,15 @@ class ReplicationLayer(object):
         defer.returnValue(self.event_from_pdu_json(pdu_dict))
 
     @log_function
-    def _get_persisted_pdu(self, origin, event_id):
+    def _get_persisted_pdu(self, origin, event_id, do_auth=True):
         """ Get a PDU from the database with given origin and id.
 
         Returns:
             Deferred: Results in a `Pdu`.
         """
-        return self.handler.get_persisted_pdu(origin, event_id)
+        return self.handler.get_persisted_pdu(
+            origin, event_id, do_auth=do_auth
+        )
 
     def _transaction_from_pdus(self, pdu_list):
         """Returns a new Transaction containing the given PDUs suitable for
@@ -523,7 +537,9 @@ class ReplicationLayer(object):
     @log_function
     def _handle_new_pdu(self, origin, pdu, backfilled=False):
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self._get_persisted_pdu(origin, pdu.event_id)
+        existing = yield self._get_persisted_pdu(
+            origin, pdu.event_id, do_auth=False
+        )
 
         if existing and (not existing.outlier or pdu.outlier):
             logger.debug("Already seen pdu %s", pdu.event_id)
@@ -532,6 +548,36 @@ class ReplicationLayer(object):
 
         state = None
 
+        # We need to make sure we have all the auth events.
+        for e_id, _ in pdu.auth_events:
+            exists = yield self._get_persisted_pdu(
+                origin,
+                e_id,
+                do_auth=False
+            )
+
+            if not exists:
+                try:
+                    logger.debug(
+                        "_handle_new_pdu fetch missing auth event %s from %s",
+                        e_id,
+                        origin,
+                    )
+
+                    yield self.get_pdu(
+                        origin,
+                        event_id=e_id,
+                        outlier=True,
+                    )
+
+                    logger.debug("Processed pdu %s", e_id)
+                except:
+                    logger.warn(
+                        "Failed to get auth event %s from %s",
+                        e_id,
+                        origin
+                    )
+
         # Get missing pdus if necessary.
         if not pdu.outlier:
             # We only backfill backwards to the min depth.
@@ -539,16 +585,28 @@ class ReplicationLayer(object):
                 pdu.room_id
             )
 
+            logger.debug(
+                "_handle_new_pdu min_depth for %s: %d",
+                pdu.room_id, min_depth
+            )
+
             if min_depth and pdu.depth > min_depth:
                 for event_id, hashes in pdu.prev_events:
-                    exists = yield self._get_persisted_pdu(origin, event_id)
+                    exists = yield self._get_persisted_pdu(
+                        origin,
+                        event_id,
+                        do_auth=False
+                    )
 
                     if not exists:
-                        logger.debug("Requesting pdu %s", event_id)
+                        logger.debug(
+                            "_handle_new_pdu requesting pdu %s",
+                            event_id
+                        )
 
                         try:
                             yield self.get_pdu(
-                                pdu.origin,
+                                origin,
                                 event_id=event_id,
                             )
                             logger.debug("Processed pdu %s", event_id)
@@ -558,6 +616,10 @@ class ReplicationLayer(object):
             else:
                 # We need to get the state at this event, since we have reached
                 # a backward extremity edge.
+                logger.debug(
+                    "_handle_new_pdu getting state for %s",
+                    pdu.room_id
+                )
                 state = yield self.get_state_for_context(
                     origin, pdu.room_id, pdu.event_id,
                 )
@@ -649,7 +711,8 @@ class _TransactionQueue(object):
                 (pdu, deferred, order)
             )
 
-            self._attempt_new_transaction(destination)
+            with PreserveLoggingContext():
+                self._attempt_new_transaction(destination)
 
             deferreds.append(deferred)
 
@@ -669,7 +732,9 @@ class _TransactionQueue(object):
                 deferred.errback(failure)
             else:
                 logger.exception("Failed to send edu", failure)
-        self._attempt_new_transaction(destination).addErrback(eb)
+
+        with PreserveLoggingContext():
+            self._attempt_new_transaction(destination).addErrback(eb)
 
         return deferred