diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 092411eaf9..d901837d0a 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -244,13 +244,14 @@ class ReplicationLayer(object):
pdu = None
if pdu_list:
pdu = pdu_list[0]
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
- def get_state_for_context(self, destination, context):
+ def get_state_for_context(self, destination, context, pdu_id=None,
+ pdu_origin=None):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.
@@ -263,13 +264,14 @@ class ReplicationLayer(object):
"""
transaction_data = yield self.transport_layer.get_context_state(
- destination, context)
+ destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin,
+ )
transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
for pdu in pdus:
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdus)
@@ -315,7 +317,7 @@ class ReplicationLayer(object):
dl = []
for pdu in pdu_list:
- dl.append(self._handle_new_pdu(pdu))
+ dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
@@ -347,14 +349,19 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def on_context_state_request(self, context):
- results = yield self.store.get_current_state_for_context(
- context
- )
+ def on_context_state_request(self, context, pdu_id, pdu_origin):
+ if pdu_id and pdu_origin:
+ pdus = yield self.handler.get_state_for_pdu(
+ pdu_id, pdu_origin
+ )
+ else:
+ results = yield self.store.get_current_state_for_context(
+ context
+ )
+ pdus = [Pdu.from_pdu_tuple(p) for p in results]
- logger.debug("Context returning %d results", len(results))
+ logger.debug("Context returning %d results", len(pdus))
- pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@@ -393,9 +400,55 @@ class ReplicationLayer(object):
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
- defer.returnValue((404, "No handler for Query type '%s'"
- % (query_type)
- ))
+ defer.returnValue(
+ (404, "No handler for Query type '%s'" % (query_type, ))
+ )
+
+ @defer.inlineCallbacks
+ def on_make_join_request(self, context, user_id):
+ pdu = yield self.handler.on_make_join_request(context, user_id)
+ defer.returnValue(pdu.get_dict())
+
+ @defer.inlineCallbacks
+ def on_invite_request(self, origin, content):
+ pdu = Pdu(**content)
+ ret_pdu = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, ret_pdu.get_dict()))
+
+ @defer.inlineCallbacks
+ def on_send_join_request(self, origin, content):
+ pdu = Pdu(**content)
+ state = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, self._transaction_from_pdus(state).get_dict()))
+
+ @defer.inlineCallbacks
+ def make_join(self, destination, context, user_id):
+ pdu_dict = yield self.transport_layer.make_join(
+ destination=destination,
+ context=context,
+ user_id=user_id,
+ )
+
+ logger.debug("Got response to make_join: %s", pdu_dict)
+
+ defer.returnValue(Pdu(**pdu_dict))
+
+ @defer.inlineCallbacks
+ def send_join(self, destination, pdu):
+ _, content = yield self.transport_layer.send_join(
+ destination,
+ pdu.context,
+ pdu.pdu_id,
+ pdu.origin,
+ pdu.get_dict(),
+ )
+
+ logger.debug("Got content: %s", content)
+ pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])]
+ for pdu in pdus:
+ yield self._handle_new_pdu(destination, pdu)
+
+ defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
@@ -427,7 +480,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def _handle_new_pdu(self, pdu, backfilled=False):
+ def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
@@ -436,6 +489,8 @@ class ReplicationLayer(object):
defer.returnValue({})
return
+ state = None
+
# Get missing pdus if necessary.
is_new = yield self.pdu_actions.is_new(pdu)
if is_new and not pdu.outlier:
@@ -459,12 +514,22 @@ class ReplicationLayer(object):
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
+ else:
+ # We need to get the state at this event, since we have reached
+ # a backward extremity edge.
+ state = yield self.get_state_for_context(
+ origin, pdu.context, pdu.pdu_id, pdu.origin,
+ )
# Persist the Pdu, but don't mark it as processed yet.
yield self.store.persist_event(pdu=pdu)
if not backfilled:
- ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+ ret = yield self.handler.on_receive_pdu(
+ pdu,
+ backfilled=backfilled,
+ state=state,
+ )
else:
ret = None
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index e7517cac4d..7f01b4faaf 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,7 +72,8 @@ class TransportLayer(object):
self.received_handler = None
@log_function
- def get_context_state(self, destination, context):
+ def get_context_state(self, destination, context, pdu_id=None,
+ pdu_origin=None):
""" Requests all state for a given context (i.e. room) from the
given server.
@@ -89,7 +90,14 @@ class TransportLayer(object):
subpath = "/state/%s/" % context
- return self._do_request_for_transaction(destination, subpath)
+ args = {}
+ if pdu_id and pdu_origin:
+ args["pdu_id"] = pdu_id
+ args["pdu_origin"] = pdu_origin
+
+ return self._do_request_for_transaction(
+ destination, subpath, args=args
+ )
@log_function
def get_pdu(self, destination, pdu_origin, pdu_id):
@@ -135,8 +143,10 @@ class TransportLayer(object):
subpath = "/backfill/%s/" % context
- args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
- args["limit"] = limit
+ args = {
+ "v": ["%s,%s" % (i, o) for i, o in pdu_tuples],
+ "limit": limit,
+ }
return self._do_request_for_transaction(
dest,
@@ -198,6 +208,59 @@ class TransportLayer(object):
defer.returnValue(response)
@defer.inlineCallbacks
+ @log_function
+ def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
+ path = PREFIX + "/make_join/%s/%s" % (context, user_id,)
+
+ response = yield self.client.get_json(
+ destination=destination,
+ path=path,
+ retry_on_dns_fail=retry_on_dns_fail,
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_join(self, destination, context, pdu_id, origin, content):
+ path = PREFIX + "/send_join/%s/%s/%s" % (
+ context,
+ origin,
+ pdu_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_join", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_invite(self, destination, context, pdu_id, origin, content):
+ path = PREFIX + "/invite/%s/%s/%s" % (
+ context,
+ origin,
+ pdu_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_invite", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
"method": request.method,
@@ -326,7 +389,11 @@ class TransportLayer(object):
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
- handler.on_context_state_request(context)
+ handler.on_context_state_request(
+ context,
+ query.get("pdu_id", [None])[0],
+ query.get("pdu_origin", [None])[0]
+ )
)
)
@@ -362,6 +429,39 @@ class TransportLayer(object):
)
)
+ self.server.register_path(
+ "GET",
+ re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, user_id:
+ self._on_make_join_request(
+ origin, content, query, context, user_id
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, pdu_origin, pdu_id:
+ self._on_send_join_request(
+ origin, content, query,
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, pdu_origin, pdu_id:
+ self._on_invite_request(
+ origin, content, query,
+ )
+ )
+ )
+
@defer.inlineCallbacks
@log_function
def _on_send_request(self, origin, content, query, transaction_id):
@@ -451,7 +551,34 @@ class TransportLayer(object):
versions = [v.split(",", 1) for v in v_list]
return self.request_handler.on_backfill_request(
- context, versions, limit)
+ context, versions, limit
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_make_join_request(self, origin, content, query, context, user_id):
+ content = yield self.request_handler.on_make_join_request(
+ context, user_id,
+ )
+ defer.returnValue((200, content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_send_join_request(self, origin, content, query):
+ content = yield self.request_handler.on_send_join_request(
+ origin, content,
+ )
+
+ defer.returnValue((200, content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_invite_request(self, origin, content, query):
+ content = yield self.request_handler.on_invite_request(
+ origin, content,
+ )
+
+ defer.returnValue((200, content))
class TransportReceivedHandler(object):
|