diff options
Diffstat (limited to 'synapse/federation/replication.py')
-rw-r--r-- | synapse/federation/replication.py | 66 |
1 files changed, 50 insertions, 16 deletions
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 08c29dece5..d482193851 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -244,13 +244,14 @@ class ReplicationLayer(object): pdu = None if pdu_list: pdu = pdu_list[0] - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdu) @defer.inlineCallbacks @log_function - def get_state_for_context(self, destination, context): + def get_state_for_context(self, destination, context, pdu_id=None, + pdu_origin=None): """Requests all of the `current` state PDUs for a given context from a remote home server. @@ -263,13 +264,14 @@ class ReplicationLayer(object): """ transaction_data = yield self.transport_layer.get_context_state( - destination, context) + destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin, + ) transaction = Transaction(**transaction_data) pdus = [Pdu(outlier=True, **p) for p in transaction.pdus] for pdu in pdus: - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdus) @@ -315,7 +317,7 @@ class ReplicationLayer(object): dl = [] for pdu in pdu_list: - dl.append(self._handle_new_pdu(pdu)) + dl.append(self._handle_new_pdu(transaction.origin, pdu)) if hasattr(transaction, "edus"): for edu in [Edu(**x) for x in transaction.edus]: @@ -347,14 +349,19 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def on_context_state_request(self, context): - results = yield self.store.get_current_state_for_context( - context - ) + def on_context_state_request(self, context, pdu_id, pdu_origin): + if pdu_id and pdu_origin: + pdus = yield self.handler.get_state_for_pdu( + pdu_id, pdu_origin + ) + else: + results = yield self.store.get_current_state_for_context( + context + ) + pdus = [Pdu.from_pdu_tuple(p) for p in results] - logger.debug("Context returning %d results", len(results)) + logger.debug("Context returning %d results", len(pdus)) - pdus = [Pdu.from_pdu_tuple(p) for p in results] defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @defer.inlineCallbacks @@ -396,9 +403,10 @@ class ReplicationLayer(object): defer.returnValue( (404, "No handler for Query type '%s'" % (query_type, )) ) - + @defer.inlineCallbacks def on_make_join_request(self, context, user_id): - return self.handler.on_make_join_request(context, user_id) + pdu = yield self.handler.on_make_join_request(context, user_id) + defer.returnValue(pdu.get_dict()) @defer.inlineCallbacks def on_send_join_request(self, origin, content): @@ -406,13 +414,27 @@ class ReplicationLayer(object): state = yield self.handler.on_send_join_request(origin, pdu) defer.returnValue((200, self._transaction_from_pdus(state).get_dict())) + @defer.inlineCallbacks def make_join(self, destination, context, user_id): - return self.transport_layer.make_join( + pdu_dict = yield self.transport_layer.make_join( destination=destination, context=context, user_id=user_id, ) + logger.debug("Got response to make_join: %s", pdu_dict) + + defer.returnValue(Pdu(**pdu_dict)) + + def send_join(self, destination, pdu): + return self.transport_layer.send_join( + destination, + pdu.context, + pdu.pdu_id, + pdu.origin, + pdu.get_dict(), + ) + @defer.inlineCallbacks @log_function def _get_persisted_pdu(self, pdu_id, pdu_origin): @@ -443,7 +465,7 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def _handle_new_pdu(self, pdu, backfilled=False): + def _handle_new_pdu(self, origin, pdu, backfilled=False): # We reprocess pdus when we have seen them only as outliers existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin) @@ -452,6 +474,8 @@ class ReplicationLayer(object): defer.returnValue({}) return + state = None + # Get missing pdus if necessary. is_new = yield self.pdu_actions.is_new(pdu) if is_new and not pdu.outlier: @@ -475,12 +499,22 @@ class ReplicationLayer(object): except: # TODO(erikj): Do some more intelligent retries. logger.exception("Failed to get PDU") + else: + # We need to get the state at this event, since we have reached + # a backward extremity edge. + state = yield self.get_state_for_context( + origin, pdu.context, pdu.pdu_id, pdu.origin, + ) # Persist the Pdu, but don't mark it as processed yet. yield self.store.persist_event(pdu=pdu) if not backfilled: - ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled) + ret = yield self.handler.on_receive_pdu( + pdu, + backfilled=backfilled, + state=state, + ) else: ret = None |