diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 092411eaf9..8ee74de005 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -19,9 +19,9 @@ a given transport.
from twisted.internet import defer
-from .units import Transaction, Pdu, Edu
+from .units import Transaction, Edu
-from .persistence import PduActions, TransactionActions
+from .persistence import TransactionActions
from synapse.util.logutils import log_function
@@ -57,7 +57,7 @@ class ReplicationLayer(object):
self.transport_layer.register_request_handler(self)
self.store = hs.get_datastore()
- self.pdu_actions = PduActions(self.store)
+ # self.pdu_actions = PduActions(self.store)
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = _TransactionQueue(
@@ -72,6 +72,8 @@ class ReplicationLayer(object):
self._clock = hs.get_clock()
+ self.event_factory = hs.get_event_factory()
+
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
@@ -81,7 +83,7 @@ class ReplicationLayer(object):
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
- raise KeyError("Already have an EDU handler for %s" % (edu_type))
+ raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
@@ -102,24 +104,17 @@ class ReplicationLayer(object):
object to encode as JSON.
"""
if query_type in self.query_handlers:
- raise KeyError("Already have a Query handler for %s" % (query_type))
+ raise KeyError(
+ "Already have a Query handler for %s" % (query_type,)
+ )
self.query_handlers[query_type] = handler
- @defer.inlineCallbacks
@log_function
def send_pdu(self, pdu):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
- This will fill out various attributes on the PDU object, e.g. the
- `prev_pdus` key.
-
- *Note:* The home server should always call `send_pdu` even if it knows
- that it does not need to be replicated to other home servers. This is
- in case e.g. someone else joins via a remote home server and then
- backfills.
-
TODO: Figure out when we should actually resolve the deferred.
Args:
@@ -132,18 +127,15 @@ class ReplicationLayer(object):
order = self._order
self._order += 1
- logger.debug("[%s] Persisting PDU", pdu.pdu_id)
-
- # Save *before* trying to send
- yield self.store.persist_event(pdu=pdu)
-
- logger.debug("[%s] Persisted PDU", pdu.pdu_id)
- logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+ logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order)
- logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+ logger.debug(
+ "[%s] transaction_layer.enqueue_pdu... done",
+ pdu.event_id
+ )
@log_function
def send_edu(self, destination, edu_type, content):
@@ -159,6 +151,11 @@ class ReplicationLayer(object):
return defer.succeed(None)
@log_function
+ def send_failure(self, failure, destination):
+ self._transaction_queue.enqueue_failure(failure, destination)
+ return defer.succeed(None)
+
+ @log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
"""Sends a federation Query to a remote homeserver of the given type
@@ -181,7 +178,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def backfill(self, dest, context, limit):
+ def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
@@ -189,12 +186,12 @@ class ReplicationLayer(object):
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
+ extremities (list): List of PDU id and origins of the first pdus
+ we have seen from the context
Returns:
Deferred: Results in the received PDUs.
"""
- extremities = yield self.store.get_oldest_pdus_in_context(context)
-
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
@@ -208,15 +205,18 @@ class ReplicationLayer(object):
transaction = Transaction(**transaction_data)
- pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
+ pdus = [
+ self.event_from_pdu_json(p, outlier=False)
+ for p in transaction.pdus
+ ]
for pdu in pdus:
- yield self._handle_new_pdu(pdu, backfilled=True)
+ yield self._handle_new_pdu(dest, pdu, backfilled=True)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
- def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+ def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.
@@ -225,7 +225,7 @@ class ReplicationLayer(object):
Args:
destination (str): Which home server to query
pdu_origin (str): The home server that originally sent the pdu.
- pdu_id (str)
+ event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
@@ -234,23 +234,27 @@ class ReplicationLayer(object):
Deferred: Results in the requested PDU.
"""
- transaction_data = yield self.transport_layer.get_pdu(
- destination, pdu_origin, pdu_id)
+ transaction_data = yield self.transport_layer.get_event(
+ destination, event_id
+ )
transaction = Transaction(**transaction_data)
- pdu_list = [Pdu(outlier=outlier, **p) for p in transaction.pdus]
+ pdu_list = [
+ self.event_from_pdu_json(p, outlier=outlier)
+ for p in transaction.pdus
+ ]
pdu = None
if pdu_list:
pdu = pdu_list[0]
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
- def get_state_for_context(self, destination, context):
+ def get_state_for_context(self, destination, context, event_id=None):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.
@@ -263,29 +267,25 @@ class ReplicationLayer(object):
"""
transaction_data = yield self.transport_layer.get_context_state(
- destination, context)
+ destination,
+ context,
+ event_id=event_id,
+ )
transaction = Transaction(**transaction_data)
-
- pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
- for pdu in pdus:
- yield self._handle_new_pdu(pdu)
+ pdus = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in transaction.pdus
+ ]
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
- def on_context_pdus_request(self, context):
- pdus = yield self.pdu_actions.get_all_pdus_from_context(
- context
+ def on_backfill_request(self, origin, context, versions, limit):
+ pdus = yield self.handler.on_backfill_request(
+ origin, context, versions, limit
)
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
-
- @defer.inlineCallbacks
- @log_function
- def on_backfill_request(self, context, versions, limit):
-
- pdus = yield self.pdu_actions.backfill(context, versions, limit)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@@ -295,11 +295,17 @@ class ReplicationLayer(object):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
- pdu_list = [Pdu(**p) for p in transaction.pdus]
+ pdu_list = [
+ self.event_from_pdu_json(p) for p in transaction.pdus
+ ]
logger.debug("[%s] Got transaction", transaction.transaction_id)
@@ -315,11 +321,15 @@ class ReplicationLayer(object):
dl = []
for pdu in pdu_list:
- dl.append(self._handle_new_pdu(pdu))
+ dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
- self.received_edu(transaction.origin, edu.edu_type, edu.content)
+ self.received_edu(
+ transaction.origin,
+ edu.edu_type,
+ edu.content
+ )
results = yield defer.DeferredList(dl)
@@ -347,20 +357,22 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def on_context_state_request(self, context):
- results = yield self.store.get_current_state_for_context(
- context
- )
-
- logger.debug("Context returning %d results", len(results))
+ def on_context_state_request(self, origin, context, event_id):
+ if event_id:
+ pdus = yield self.handler.get_state_for_pdu(
+ origin,
+ context,
+ event_id,
+ )
+ else:
+ raise NotImplementedError("Specify an event")
- pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
- def on_pdu_request(self, pdu_origin, pdu_id):
- pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+ def on_pdu_request(self, origin, event_id):
+ pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
@@ -372,116 +384,207 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
- transaction_id = max([int(v) for v in versions])
+ raise NotImplementedError("Pull transacions not implemented")
- response = yield self.pdu_actions.after_transaction(
- transaction_id,
- origin,
- self.server_name
+ @defer.inlineCallbacks
+ def on_query_request(self, query_type, args):
+ if query_type in self.query_handlers:
+ response = yield self.query_handlers[query_type](args)
+ defer.returnValue((200, response))
+ else:
+ defer.returnValue(
+ (404, "No handler for Query type '%s'" % (query_type, ))
+ )
+
+ @defer.inlineCallbacks
+ def on_make_join_request(self, context, user_id):
+ pdu = yield self.handler.on_make_join_request(context, user_id)
+ defer.returnValue({
+ "event": pdu.get_pdu_json(),
+ })
+
+ @defer.inlineCallbacks
+ def on_invite_request(self, origin, content):
+ pdu = self.event_from_pdu_json(content)
+ ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+ defer.returnValue(
+ (
+ 200,
+ {
+ "event": ret_pdu.get_pdu_json(),
+ }
+ )
)
- if not response:
- response = []
+ @defer.inlineCallbacks
+ def on_send_join_request(self, origin, content):
+ pdu = self.event_from_pdu_json(content)
+ res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, {
+ "state": [p.get_pdu_json() for p in res_pdus["state"]],
+ "auth_chain": [p.get_pdu_json() for p in res_pdus["auth_chain"]],
+ }))
+
+ @defer.inlineCallbacks
+ def on_event_auth(self, origin, context, event_id):
+ auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue(
- (200, self._transaction_from_pdus(response).get_dict())
+ (
+ 200,
+ {
+ "auth_chain": [a.get_pdu_json() for a in auth_pdus],
+ }
+ )
)
@defer.inlineCallbacks
- def on_query_request(self, query_type, args):
- if query_type in self.query_handlers:
- response = yield self.query_handlers[query_type](args)
- defer.returnValue((200, response))
- else:
- defer.returnValue((404, "No handler for Query type '%s'"
- % (query_type)
- ))
+ def make_join(self, destination, context, user_id):
+ ret = yield self.transport_layer.make_join(
+ destination=destination,
+ context=context,
+ user_id=user_id,
+ )
+
+ pdu_dict = ret["event"]
+
+ logger.debug("Got response to make_join: %s", pdu_dict)
+
+ defer.returnValue(self.event_from_pdu_json(pdu_dict))
@defer.inlineCallbacks
+ def send_join(self, destination, pdu):
+ _, content = yield self.transport_layer.send_join(
+ destination,
+ pdu.room_id,
+ pdu.event_id,
+ pdu.get_pdu_json(),
+ )
+
+ logger.debug("Got content: %s", content)
+
+ state = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in content.get("state", [])
+ ]
+
+ # FIXME: We probably want to do something with the auth_chain given
+ # to us
+
+ # auth_chain = [
+ # Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
+ # ]
+
+ defer.returnValue(state)
+
+ @defer.inlineCallbacks
+ def send_invite(self, destination, context, event_id, pdu):
+ code, content = yield self.transport_layer.send_invite(
+ destination=destination,
+ context=context,
+ event_id=event_id,
+ content=pdu.get_pdu_json(),
+ )
+
+ pdu_dict = content["event"]
+
+ logger.debug("Got response to send_invite: %s", pdu_dict)
+
+ defer.returnValue(self.event_from_pdu_json(pdu_dict))
+
@log_function
- def _get_persisted_pdu(self, pdu_id, pdu_origin):
+ def _get_persisted_pdu(self, origin, event_id):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
- pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
-
- defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+ return self.handler.get_persisted_pdu(origin, event_id)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
- pdus = [p.get_dict() for p in pdu_list]
+ pdus = [p.get_pdu_json() for p in pdu_list]
+ time_now = self._clock.time_msec()
for p in pdus:
- if "age_ts" in pdus:
- p["age"] = int(self.clock.time_msec()) - p["age_ts"]
-
+ if "age_ts" in p:
+ age = time_now - p["age_ts"]
+ p.setdefault("unsigned", {})["age"] = int(age)
+ del p["age_ts"]
return Transaction(
origin=self.server_name,
pdus=pdus,
- origin_server_ts=int(self._clock.time_msec()),
+ origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
- def _handle_new_pdu(self, pdu, backfilled=False):
+ def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
- existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+ existing = yield self._get_persisted_pdu(origin, pdu.event_id)
if existing and (not existing.outlier or pdu.outlier):
- logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+ logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
+ state = None
+
# Get missing pdus if necessary.
- is_new = yield self.pdu_actions.is_new(pdu)
- if is_new and not pdu.outlier:
+ if not pdu.outlier:
# We only backfill backwards to the min depth.
- min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+ min_depth = yield self.handler.get_min_depth_for_context(
+ pdu.room_id
+ )
if min_depth and pdu.depth > min_depth:
- for pdu_id, origin in pdu.prev_pdus:
- exists = yield self._get_persisted_pdu(pdu_id, origin)
+ for event_id, hashes in pdu.prev_events:
+ exists = yield self._get_persisted_pdu(origin, event_id)
if not exists:
- logger.debug("Requesting pdu %s %s", pdu_id, origin)
+ logger.debug("Requesting pdu %s", event_id)
try:
yield self.get_pdu(
pdu.origin,
- pdu_id=pdu_id,
- pdu_origin=origin
+ event_id=event_id,
)
- logger.debug("Processed pdu %s %s", pdu_id, origin)
+ logger.debug("Processed pdu %s", event_id)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
-
- # Persist the Pdu, but don't mark it as processed yet.
- yield self.store.persist_event(pdu=pdu)
+ else:
+ # We need to get the state at this event, since we have reached
+ # a backward extremity edge.
+ state = yield self.get_state_for_context(
+ origin, pdu.room_id, pdu.event_id,
+ )
if not backfilled:
- ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+ ret = yield self.handler.on_receive_pdu(
+ pdu,
+ backfilled=backfilled,
+ state=state,
+ )
else:
ret = None
- yield self.pdu_actions.mark_as_processed(pdu)
+ # yield self.pdu_actions.mark_as_processed(pdu)
defer.returnValue(ret)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
-
-class ReplicationHandler(object):
- """This defines the methods that the :py:class:`.ReplicationLayer` will
- use to communicate with the rest of the home server.
- """
- def on_receive_pdu(self, pdu):
- raise NotImplementedError("on_receive_pdu")
+ def event_from_pdu_json(self, pdu_json, outlier=False):
+ #TODO: Check we have all the PDU keys here
+ pdu_json.setdefault("hashes", {})
+ pdu_json.setdefault("signatures", {})
+ return self.event_factory.create_event(
+ pdu_json["type"], outlier=outlier, **pdu_json
+ )
class _TransactionQueue(object):
@@ -509,6 +612,9 @@ class _TransactionQueue(object):
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
+ # destination -> list of tuple(failure, deferred)
+ self.pending_failures_by_dest = {}
+
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@@ -562,6 +668,18 @@ class _TransactionQueue(object):
return deferred
@defer.inlineCallbacks
+ def enqueue_failure(self, failure, destination):
+ deferred = defer.Deferred()
+
+ self.pending_failures_by_dest.setdefault(
+ destination, []
+ ).append(
+ (failure, deferred)
+ )
+
+ yield deferred
+
+ @defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
if destination in self.pending_transactions:
@@ -570,8 +688,9 @@ class _TransactionQueue(object):
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
- if not pending_pdus and not pending_edus:
+ if not pending_pdus and not pending_edus and not pending_failures:
return
logger.debug("TX [%s] Attempting new transaction", destination)
@@ -581,7 +700,11 @@ class _TransactionQueue(object):
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
- deferreds = [x[1] for x in pending_pdus + pending_edus]
+ failures = [x[0].get_dict() for x in pending_failures]
+ deferreds = [
+ x[1]
+ for x in pending_pdus + pending_edus + pending_failures
+ ]
try:
self.pending_transactions[destination] = 1
@@ -589,12 +712,13 @@ class _TransactionQueue(object):
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
- origin_server_ts=self._clock.time_msec(),
+ origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
+ pdu_failures=failures,
)
self._next_txn_id += 1
@@ -614,7 +738,9 @@ class _TransactionQueue(object):
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
- p["age"] = now - int(p["age_ts"])
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
return data
code, response = yield self.transport_layer.send_transaction(
|