summary refs log tree commit diff
path: root/synapse/federation/replication.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/replication.py')
-rw-r--r--synapse/federation/replication.py66
1 files changed, 50 insertions, 16 deletions
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 08c29dece5..d482193851 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -244,13 +244,14 @@ class ReplicationLayer(object):
         pdu = None
         if pdu_list:
             pdu = pdu_list[0]
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdu)
 
     @defer.inlineCallbacks
     @log_function
-    def get_state_for_context(self, destination, context):
+    def get_state_for_context(self, destination, context, pdu_id=None,
+                              pdu_origin=None):
         """Requests all of the `current` state PDUs for a given context from
         a remote home server.
 
@@ -263,13 +264,14 @@ class ReplicationLayer(object):
         """
 
         transaction_data = yield self.transport_layer.get_context_state(
-            destination, context)
+            destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin,
+        )
 
         transaction = Transaction(**transaction_data)
 
         pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
         for pdu in pdus:
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdus)
 
@@ -315,7 +317,7 @@ class ReplicationLayer(object):
 
         dl = []
         for pdu in pdu_list:
-            dl.append(self._handle_new_pdu(pdu))
+            dl.append(self._handle_new_pdu(transaction.origin, pdu))
 
         if hasattr(transaction, "edus"):
             for edu in [Edu(**x) for x in transaction.edus]:
@@ -347,14 +349,19 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def on_context_state_request(self, context):
-        results = yield self.store.get_current_state_for_context(
-            context
-        )
+    def on_context_state_request(self, context, pdu_id, pdu_origin):
+        if pdu_id and pdu_origin:
+            pdus = yield self.handler.get_state_for_pdu(
+                pdu_id, pdu_origin
+            )
+        else:
+            results = yield self.store.get_current_state_for_context(
+                context
+            )
+            pdus = [Pdu.from_pdu_tuple(p) for p in results]
 
-        logger.debug("Context returning %d results", len(results))
+        logger.debug("Context returning %d results", len(pdus))
 
-        pdus = [Pdu.from_pdu_tuple(p) for p in results]
         defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
     @defer.inlineCallbacks
@@ -396,9 +403,10 @@ class ReplicationLayer(object):
             defer.returnValue(
                 (404, "No handler for Query type '%s'" % (query_type, ))
             )
-
+    @defer.inlineCallbacks
     def on_make_join_request(self, context, user_id):
-        return self.handler.on_make_join_request(context, user_id)
+        pdu = yield self.handler.on_make_join_request(context, user_id)
+        defer.returnValue(pdu.get_dict())
 
     @defer.inlineCallbacks
     def on_send_join_request(self, origin, content):
@@ -406,13 +414,27 @@ class ReplicationLayer(object):
         state = yield self.handler.on_send_join_request(origin, pdu)
         defer.returnValue((200, self._transaction_from_pdus(state).get_dict()))
 
+    @defer.inlineCallbacks
     def make_join(self, destination, context, user_id):
-        return self.transport_layer.make_join(
+        pdu_dict = yield self.transport_layer.make_join(
             destination=destination,
             context=context,
             user_id=user_id,
         )
 
+        logger.debug("Got response to make_join: %s", pdu_dict)
+
+        defer.returnValue(Pdu(**pdu_dict))
+
+    def send_join(self, destination, pdu):
+        return self.transport_layer.send_join(
+            destination,
+            pdu.context,
+            pdu.pdu_id,
+            pdu.origin,
+            pdu.get_dict(),
+        )
+
     @defer.inlineCallbacks
     @log_function
     def _get_persisted_pdu(self, pdu_id, pdu_origin):
@@ -443,7 +465,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_pdu(self, pdu, backfilled=False):
+    def _handle_new_pdu(self, origin, pdu, backfilled=False):
         # We reprocess pdus when we have seen them only as outliers
         existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
 
@@ -452,6 +474,8 @@ class ReplicationLayer(object):
             defer.returnValue({})
             return
 
+        state = None
+
         # Get missing pdus if necessary.
         is_new = yield self.pdu_actions.is_new(pdu)
         if is_new and not pdu.outlier:
@@ -475,12 +499,22 @@ class ReplicationLayer(object):
                         except:
                             # TODO(erikj): Do some more intelligent retries.
                             logger.exception("Failed to get PDU")
+            else:
+                # We need to get the state at this event, since we have reached
+                # a backward extremity edge.
+                state = yield self.get_state_for_context(
+                    origin, pdu.context, pdu.pdu_id, pdu.origin,
+                )
 
         # Persist the Pdu, but don't mark it as processed yet.
         yield self.store.persist_event(pdu=pdu)
 
         if not backfilled:
-            ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+            ret = yield self.handler.on_receive_pdu(
+                pdu,
+                backfilled=backfilled,
+                state=state,
+            )
         else:
             ret = None