summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/replication.py66
-rw-r--r--synapse/federation/transport.py68
2 files changed, 110 insertions, 24 deletions
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 08c29dece5..d482193851 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -244,13 +244,14 @@ class ReplicationLayer(object):
         pdu = None
         if pdu_list:
             pdu = pdu_list[0]
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdu)
 
     @defer.inlineCallbacks
     @log_function
-    def get_state_for_context(self, destination, context):
+    def get_state_for_context(self, destination, context, pdu_id=None,
+                              pdu_origin=None):
         """Requests all of the `current` state PDUs for a given context from
         a remote home server.
 
@@ -263,13 +264,14 @@ class ReplicationLayer(object):
         """
 
         transaction_data = yield self.transport_layer.get_context_state(
-            destination, context)
+            destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin,
+        )
 
         transaction = Transaction(**transaction_data)
 
         pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
         for pdu in pdus:
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdus)
 
@@ -315,7 +317,7 @@ class ReplicationLayer(object):
 
         dl = []
         for pdu in pdu_list:
-            dl.append(self._handle_new_pdu(pdu))
+            dl.append(self._handle_new_pdu(transaction.origin, pdu))
 
         if hasattr(transaction, "edus"):
             for edu in [Edu(**x) for x in transaction.edus]:
@@ -347,14 +349,19 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def on_context_state_request(self, context):
-        results = yield self.store.get_current_state_for_context(
-            context
-        )
+    def on_context_state_request(self, context, pdu_id, pdu_origin):
+        if pdu_id and pdu_origin:
+            pdus = yield self.handler.get_state_for_pdu(
+                pdu_id, pdu_origin
+            )
+        else:
+            results = yield self.store.get_current_state_for_context(
+                context
+            )
+            pdus = [Pdu.from_pdu_tuple(p) for p in results]
 
-        logger.debug("Context returning %d results", len(results))
+        logger.debug("Context returning %d results", len(pdus))
 
-        pdus = [Pdu.from_pdu_tuple(p) for p in results]
         defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
     @defer.inlineCallbacks
@@ -396,9 +403,10 @@ class ReplicationLayer(object):
             defer.returnValue(
                 (404, "No handler for Query type '%s'" % (query_type, ))
             )
-
+    @defer.inlineCallbacks
     def on_make_join_request(self, context, user_id):
-        return self.handler.on_make_join_request(context, user_id)
+        pdu = yield self.handler.on_make_join_request(context, user_id)
+        defer.returnValue(pdu.get_dict())
 
     @defer.inlineCallbacks
     def on_send_join_request(self, origin, content):
@@ -406,13 +414,27 @@ class ReplicationLayer(object):
         state = yield self.handler.on_send_join_request(origin, pdu)
         defer.returnValue((200, self._transaction_from_pdus(state).get_dict()))
 
+    @defer.inlineCallbacks
     def make_join(self, destination, context, user_id):
-        return self.transport_layer.make_join(
+        pdu_dict = yield self.transport_layer.make_join(
             destination=destination,
             context=context,
             user_id=user_id,
         )
 
+        logger.debug("Got response to make_join: %s", pdu_dict)
+
+        defer.returnValue(Pdu(**pdu_dict))
+
+    def send_join(self, destination, pdu):
+        return self.transport_layer.send_join(
+            destination,
+            pdu.context,
+            pdu.pdu_id,
+            pdu.origin,
+            pdu.get_dict(),
+        )
+
     @defer.inlineCallbacks
     @log_function
     def _get_persisted_pdu(self, pdu_id, pdu_origin):
@@ -443,7 +465,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_pdu(self, pdu, backfilled=False):
+    def _handle_new_pdu(self, origin, pdu, backfilled=False):
         # We reprocess pdus when we have seen them only as outliers
         existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
 
@@ -452,6 +474,8 @@ class ReplicationLayer(object):
             defer.returnValue({})
             return
 
+        state = None
+
         # Get missing pdus if necessary.
         is_new = yield self.pdu_actions.is_new(pdu)
         if is_new and not pdu.outlier:
@@ -475,12 +499,22 @@ class ReplicationLayer(object):
                         except:
                             # TODO(erikj): Do some more intelligent retries.
                             logger.exception("Failed to get PDU")
+            else:
+                # We need to get the state at this event, since we have reached
+                # a backward extremity edge.
+                state = yield self.get_state_for_context(
+                    origin, pdu.context, pdu.pdu_id, pdu.origin,
+                )
 
         # Persist the Pdu, but don't mark it as processed yet.
         yield self.store.persist_event(pdu=pdu)
 
         if not backfilled:
-            ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+            ret = yield self.handler.on_receive_pdu(
+                pdu,
+                backfilled=backfilled,
+                state=state,
+            )
         else:
             ret = None
 
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index 4f552272e6..a0d34fd24d 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,7 +72,8 @@ class TransportLayer(object):
         self.received_handler = None
 
     @log_function
-    def get_context_state(self, destination, context):
+    def get_context_state(self, destination, context, pdu_id=None,
+                          pdu_origin=None):
         """ Requests all state for a given context (i.e. room) from the
         given server.
 
@@ -89,7 +90,14 @@ class TransportLayer(object):
 
         subpath = "/state/%s/" % context
 
-        return self._do_request_for_transaction(destination, subpath)
+        args = {}
+        if pdu_id and pdu_origin:
+            args["pdu_id"] = pdu_id
+            args["pdu_origin"] = pdu_origin
+
+        return self._do_request_for_transaction(
+            destination, subpath, args=args
+        )
 
     @log_function
     def get_pdu(self, destination, pdu_origin, pdu_id):
@@ -135,8 +143,10 @@ class TransportLayer(object):
 
         subpath = "/backfill/%s/" % context
 
-        args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
-        args["limit"] = limit
+        args = {
+            "v": ["%s,%s" % (i, o) for i, o in pdu_tuples],
+            "limit": limit,
+        }
 
         return self._do_request_for_transaction(
             dest,
@@ -211,6 +221,23 @@ class TransportLayer(object):
         defer.returnValue(response)
 
     @defer.inlineCallbacks
+    @log_function
+    def send_join(self, destination, context, pdu_id, origin, content):
+        path = PREFIX + "/send_join/%s/%s/%s" % (
+            context,
+            origin,
+            pdu_id,
+        )
+
+        response = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        defer.returnValue(response)
+
+    @defer.inlineCallbacks
     def _authenticate_request(self, request):
         json_request = {
             "method": request.method,
@@ -330,7 +357,11 @@ class TransportLayer(object):
             re.compile("^" + PREFIX + "/state/([^/]*)/$"),
             self._with_authentication(
                 lambda origin, content, query, context:
-                handler.on_context_state_request(context)
+                handler.on_context_state_request(
+                    context,
+                    query.get("pdu_id", [None])[0],
+                    query.get("pdu_origin", [None])[0]
+                )
             )
         )
 
@@ -369,7 +400,23 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
-            self._on_make_join_request
+            self._with_authentication(
+                lambda origin, content, query, context, user_id:
+                self._on_make_join_request(
+                    origin, content, query, context, user_id
+                )
+            )
+        )
+
+        self.server.register_path(
+            "PUT",
+            re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, pdu_origin, pdu_id:
+                self._on_send_join_request(
+                    origin, content, query,
+                )
+            )
         )
 
     @defer.inlineCallbacks
@@ -460,18 +507,23 @@ class TransportLayer(object):
             context, versions, limit
         )
 
+    @defer.inlineCallbacks
     @log_function
     def _on_make_join_request(self, origin, content, query, context, user_id):
-        return self.request_handler.on_make_join_request(
+        content = yield self.request_handler.on_make_join_request(
             context, user_id,
         )
+        defer.returnValue((200, content))
 
+    @defer.inlineCallbacks
     @log_function
     def _on_send_join_request(self, origin, content, query):
-        return self.request_handler.on_send_join_request(
+        content = yield self.request_handler.on_send_join_request(
             origin, content,
         )
 
+        defer.returnValue((200, content))
+
 
 class TransportReceivedHandler(object):
     """ Callbacks used when we receive a transaction