diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
index e8180d94fd..52c84efb5b 100644
--- a/synapse/federation/pdu_codec.py
+++ b/synapse/federation/pdu_codec.py
@@ -18,50 +18,25 @@ from .units import Pdu
import copy
-def decode_event_id(event_id, server_name):
- parts = event_id.split("@")
- if len(parts) < 2:
- return (event_id, server_name)
- else:
- return (parts[0], "".join(parts[1:]))
-
-
-def encode_event_id(pdu_id, origin):
- return "%s@%s" % (pdu_id, origin)
-
-
class PduCodec(object):
def __init__(self, hs):
+ self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.event_factory = hs.get_event_factory()
self.clock = hs.get_clock()
+ self.hs = hs
def event_from_pdu(self, pdu):
kwargs = {}
- kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
- kwargs["room_id"] = pdu.context
- kwargs["etype"] = pdu.pdu_type
- kwargs["prev_events"] = [
- encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
- ]
-
- if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
- kwargs["prev_state"] = encode_event_id(
- pdu.prev_state_id, pdu.prev_state_origin
- )
+ kwargs["etype"] = pdu.type
kwargs.update({
k: v
for k, v in pdu.get_full_dict().items()
if k not in [
- "pdu_id",
- "context",
- "pdu_type",
- "prev_pdus",
- "prev_state_id",
- "prev_state_origin",
+ "type",
]
})
@@ -70,33 +45,10 @@ class PduCodec(object):
def pdu_from_event(self, event):
d = event.get_full_dict()
- d["pdu_id"], d["origin"] = decode_event_id(
- event.event_id, self.server_name
- )
- d["context"] = event.room_id
- d["pdu_type"] = event.type
-
- if hasattr(event, "prev_events"):
- d["prev_pdus"] = [
- decode_event_id(e, self.server_name)
- for e in event.prev_events
- ]
-
- if hasattr(event, "prev_state"):
- d["prev_state_id"], d["prev_state_origin"] = (
- decode_event_id(event.prev_state, self.server_name)
- )
-
- if hasattr(event, "state_key"):
- d["is_state"] = True
-
kwargs = copy.deepcopy(event.unrecognized_keys)
kwargs.update({
k: v for k, v in d.items()
- if k not in ["event_id", "room_id", "type", "prev_events"]
})
- if "origin_server_ts" not in kwargs:
- kwargs["origin_server_ts"] = int(self.clock.time_msec())
-
- return Pdu(**kwargs)
+ pdu = Pdu(**kwargs)
+ return pdu
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 7043fcc504..73dc844d59 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -21,8 +21,6 @@ These actions are mostly only used by the :py:mod:`.replication` module.
from twisted.internet import defer
-from .units import Pdu
-
from synapse.util.logutils import log_function
import json
@@ -32,76 +30,6 @@ import logging
logger = logging.getLogger(__name__)
-class PduActions(object):
- """ Defines persistence actions that relate to handling PDUs.
- """
-
- def __init__(self, datastore):
- self.store = datastore
-
- @log_function
- def mark_as_processed(self, pdu):
- """ Persist the fact that we have fully processed the given `Pdu`
-
- Returns:
- Deferred
- """
- return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
-
- @defer.inlineCallbacks
- @log_function
- def after_transaction(self, transaction_id, destination, origin):
- """ Returns all `Pdu`s that we sent to the given remote home server
- after a given transaction id.
-
- Returns:
- Deferred: Results in a list of `Pdu`s
- """
- results = yield self.store.get_pdus_after_transaction(
- transaction_id,
- destination
- )
-
- defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
- @defer.inlineCallbacks
- @log_function
- def get_all_pdus_from_context(self, context):
- results = yield self.store.get_all_pdus_from_context(context)
- defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
- @defer.inlineCallbacks
- @log_function
- def backfill(self, context, pdu_list, limit):
- """ For a given list of PDU id and origins return the proceeding
- `limit` `Pdu`s in the given `context`.
-
- Returns:
- Deferred: Results in a list of `Pdu`s.
- """
- results = yield self.store.get_backfill(
- context, pdu_list, limit
- )
-
- defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
- @log_function
- def is_new(self, pdu):
- """ When we receive a `Pdu` from a remote home server, we want to
- figure out whether it is `new`, i.e. it is not some historic PDU that
- we haven't seen simply because we haven't backfilled back that far.
-
- Returns:
- Deferred: Results in a `bool`
- """
- return self.store.is_pdu_new(
- pdu_id=pdu.pdu_id,
- origin=pdu.origin,
- context=pdu.context,
- depth=pdu.depth
- )
-
-
class TransactionActions(object):
""" Defines persistence actions that relate to handling Transactions.
"""
@@ -158,7 +86,6 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
transaction.origin_server_ts,
- [(p["pdu_id"], p["origin"]) for p in transaction.pdus]
)
@log_function
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 092411eaf9..a07e307849 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from .units import Transaction, Pdu, Edu
-from .persistence import PduActions, TransactionActions
+from .persistence import TransactionActions
from synapse.util.logutils import log_function
@@ -57,7 +57,7 @@ class ReplicationLayer(object):
self.transport_layer.register_request_handler(self)
self.store = hs.get_datastore()
- self.pdu_actions = PduActions(self.store)
+ # self.pdu_actions = PduActions(self.store)
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = _TransactionQueue(
@@ -81,7 +81,7 @@ class ReplicationLayer(object):
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
- raise KeyError("Already have an EDU handler for %s" % (edu_type))
+ raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
@@ -102,24 +102,17 @@ class ReplicationLayer(object):
object to encode as JSON.
"""
if query_type in self.query_handlers:
- raise KeyError("Already have a Query handler for %s" % (query_type))
+ raise KeyError(
+ "Already have a Query handler for %s" % (query_type,)
+ )
self.query_handlers[query_type] = handler
- @defer.inlineCallbacks
@log_function
def send_pdu(self, pdu):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
- This will fill out various attributes on the PDU object, e.g. the
- `prev_pdus` key.
-
- *Note:* The home server should always call `send_pdu` even if it knows
- that it does not need to be replicated to other home servers. This is
- in case e.g. someone else joins via a remote home server and then
- backfills.
-
TODO: Figure out when we should actually resolve the deferred.
Args:
@@ -132,18 +125,15 @@ class ReplicationLayer(object):
order = self._order
self._order += 1
- logger.debug("[%s] Persisting PDU", pdu.pdu_id)
-
- # Save *before* trying to send
- yield self.store.persist_event(pdu=pdu)
-
- logger.debug("[%s] Persisted PDU", pdu.pdu_id)
- logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+ logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order)
- logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+ logger.debug(
+ "[%s] transaction_layer.enqueue_pdu... done",
+ pdu.event_id
+ )
@log_function
def send_edu(self, destination, edu_type, content):
@@ -159,6 +149,11 @@ class ReplicationLayer(object):
return defer.succeed(None)
@log_function
+ def send_failure(self, failure, destination):
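+ """ Informs the replication layer about a failure to be sent to the
+ given destination in the next transaction.
+ """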
+ self._transaction_queue.enqueue_failure(failure, destination)
+ return defer.succeed(None)
+
+ @log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
"""Sends a federation Query to a remote homeserver of the given type
@@ -181,7 +176,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def backfill(self, dest, context, limit):
+ def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
@@ -189,12 +184,12 @@ class ReplicationLayer(object):
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
+ extremities (list): List of PDU ids and origins of the oldest
+ PDUs we have seen in the context.
Returns:
Deferred: Results in the received PDUs.
"""
- extremities = yield self.store.get_oldest_pdus_in_context(context)
-
logger.debug("backfill extrem=%s", extremities)
# If there are no extremities then we've (probably) reached the start.
@@ -210,13 +205,13 @@ class ReplicationLayer(object):
pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
for pdu in pdus:
- yield self._handle_new_pdu(pdu, backfilled=True)
+ yield self._handle_new_pdu(dest, pdu, backfilled=True)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
- def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+ def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.
@@ -225,7 +220,7 @@ class ReplicationLayer(object):
Args:
destination (str): Which home server to query
pdu_origin (str): The home server that originally sent the pdu.
- pdu_id (str)
+ event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
@@ -234,8 +229,9 @@ class ReplicationLayer(object):
Deferred: Results in the requested PDU.
"""
- transaction_data = yield self.transport_layer.get_pdu(
- destination, pdu_origin, pdu_id)
+ transaction_data = yield self.transport_layer.get_event(
+ destination, event_id
+ )
transaction = Transaction(**transaction_data)
@@ -244,13 +240,13 @@ class ReplicationLayer(object):
pdu = None
if pdu_list:
pdu = pdu_list[0]
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
- def get_state_for_context(self, destination, context):
+ def get_state_for_context(self, destination, context, event_id=None):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.
@@ -263,29 +259,23 @@ class ReplicationLayer(object):
"""
transaction_data = yield self.transport_layer.get_context_state(
- destination, context)
+ destination,
+ context,
+ event_id=event_id,
+ )
transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
- for pdu in pdus:
- yield self._handle_new_pdu(pdu)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
- def on_context_pdus_request(self, context):
- pdus = yield self.pdu_actions.get_all_pdus_from_context(
- context
+ def on_backfill_request(self, origin, context, versions, limit):
+ pdus = yield self.handler.on_backfill_request(
+ origin, context, versions, limit
)
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
-
- @defer.inlineCallbacks
- @log_function
- def on_backfill_request(self, context, versions, limit):
-
- pdus = yield self.pdu_actions.backfill(context, versions, limit)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@@ -295,6 +285,10 @@ class ReplicationLayer(object):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
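+ # "age" may be nested under "unsigned"; lift it to the top level
+ # so the existing age -> age_ts conversion below still applies.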
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
@@ -315,11 +309,15 @@ class ReplicationLayer(object):
dl = []
for pdu in pdu_list:
- dl.append(self._handle_new_pdu(pdu))
+ dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
- self.received_edu(transaction.origin, edu.edu_type, edu.content)
+ self.received_edu(
+ transaction.origin,
+ edu.edu_type,
+ edu.content
+ )
results = yield defer.DeferredList(dl)
@@ -347,20 +345,22 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def on_context_state_request(self, context):
- results = yield self.store.get_current_state_for_context(
- context
- )
-
- logger.debug("Context returning %d results", len(results))
+ def on_context_state_request(self, origin, context, event_id):
+ if event_id:
+ pdus = yield self.handler.get_state_for_pdu(
+ origin,
+ context,
+ event_id,
+ )
+ else:
+ raise NotImplementedError("Specify an event")
- pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
- def on_pdu_request(self, pdu_origin, pdu_id):
- pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+ def on_pdu_request(self, origin, event_id):
+ pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
@@ -372,103 +372,191 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
- transaction_id = max([int(v) for v in versions])
+ raise NotImplementedError("Pull transacions not implemented")
+
+ @defer.inlineCallbacks
+ def on_query_request(self, query_type, args):
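+ """ Dispatches an incoming federation Query to the handler
+ registered for its query type, or returns a 404 if there is none.
+ """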
+ if query_type in self.query_handlers:
+ response = yield self.query_handlers[query_type](args)
+ defer.returnValue((200, response))
+ else:
+ defer.returnValue(
+ (404, "No handler for Query type '%s'" % (query_type, ))
+ )
+
+ @defer.inlineCallbacks
+ def on_make_join_request(self, context, user_id):
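+ """ Asks the local handler to build a join event for the given user
+ in the given context, and returns it under the "event" key.
+ """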
+ pdu = yield self.handler.on_make_join_request(context, user_id)
+ defer.returnValue({
+ "event": pdu.get_dict(),
+ })
- response = yield self.pdu_actions.after_transaction(
- transaction_id,
- origin,
- self.server_name
+ @defer.inlineCallbacks
+ def on_invite_request(self, origin, content):
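+ """ Handles an invite PDU sent to us by a remote server, returning
+ the resulting event back to the inviting server.
+ """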
+ pdu = Pdu(**content)
+ ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+ defer.returnValue(
+ (
+ 200,
+ {
+ "event": ret_pdu.get_dict(),
+ }
+ )
)
- if not response:
- response = []
+ @defer.inlineCallbacks
+ def on_send_join_request(self, origin, content):
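+ """ Handles a join event sent to us by a remote server, responding
+ with the current state and auth chain for the room.
+ """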
+ pdu = Pdu(**content)
+ res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+
+ defer.returnValue((200, {
+ "state": [p.get_dict() for p in res_pdus["state"]],
+ "auth_chain": [p.get_dict() for p in res_pdus["auth_chain"]],
+ }))
+ @defer.inlineCallbacks
+ def on_event_auth(self, origin, context, event_id):
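+ """ Returns the auth chain for the given event. """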
+ auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue(
- (200, self._transaction_from_pdus(response).get_dict())
+ (
+ 200,
+ {
+ "auth_chain": [a.get_dict() for a in auth_pdus],
+ }
+ )
)
@defer.inlineCallbacks
- def on_query_request(self, query_type, args):
- if query_type in self.query_handlers:
- response = yield self.query_handlers[query_type](args)
- defer.returnValue((200, response))
- else:
- defer.returnValue((404, "No handler for Query type '%s'"
- % (query_type)
- ))
+ def make_join(self, destination, context, user_id):
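+ """ Asks the given remote home server to build a join event for the
+ given user in the given context, and returns it as a `Pdu`.
+ """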
+ ret = yield self.transport_layer.make_join(
+ destination=destination,
+ context=context,
+ user_id=user_id,
+ )
+
+ pdu_dict = ret["event"]
+
+ logger.debug("Got response to make_join: %s", pdu_dict)
+
+ defer.returnValue(Pdu(**pdu_dict))
@defer.inlineCallbacks
+ def send_join(self, destination, pdu):
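+ """ Sends a join event to the given remote home server and returns
+ the room state `Pdu`s included in its response.
+ """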
+ _, content = yield self.transport_layer.send_join(
+ destination,
+ pdu.room_id,
+ pdu.event_id,
+ pdu.get_dict(),
+ )
+
+ logger.debug("Got content: %s", content)
+
+ state = [Pdu(outlier=True, **p) for p in content.get("state", [])]
+
+ # FIXME: We probably want to do something with the auth_chain given
+ # to us
+
+ # auth_chain = [
+ # Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
+ # ]
+
+ defer.returnValue(state)
+
+ @defer.inlineCallbacks
+ def send_invite(self, destination, context, event_id, pdu):
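+ """ Sends an invite event to the given remote home server and
+ returns the event from its response as a `Pdu`.
+ """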
+ code, content = yield self.transport_layer.send_invite(
+ destination=destination,
+ context=context,
+ event_id=event_id,
+ content=pdu.get_dict(),
+ )
+
+ pdu_dict = content["event"]
+
+ logger.debug("Got response to send_invite: %s", pdu_dict)
+
+ defer.returnValue(Pdu(**pdu_dict))
+
@log_function
- def _get_persisted_pdu(self, pdu_id, pdu_origin):
+ def _get_persisted_pdu(self, origin, event_id):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
- pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
-
- defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+ return self.handler.get_persisted_pdu(origin, event_id)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
pdus = [p.get_dict() for p in pdu_list]
+ time_now = self._clock.time_msec()
for p in pdus:
- if "age_ts" in pdus:
- p["age"] = int(self.clock.time_msec()) - p["age_ts"]
-
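+ # Convert the absolute "age_ts" into a relative "age" (in ms),
+ # nested under "unsigned", for the outgoing transaction.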
+ if "age_ts" in p:
+ age = time_now - p["age_ts"]
+ p.setdefault("unsigned", {})["age"] = int(age)
+ del p["age_ts"]
return Transaction(
origin=self.server_name,
pdus=pdus,
- origin_server_ts=int(self._clock.time_msec()),
+ origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
- def _handle_new_pdu(self, pdu, backfilled=False):
+ def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
- existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+ existing = yield self._get_persisted_pdu(origin, pdu.event_id)
if existing and (not existing.outlier or pdu.outlier):
- logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+ logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
+ state = None
+
# Get missing pdus if necessary.
- is_new = yield self.pdu_actions.is_new(pdu)
- if is_new and not pdu.outlier:
+ if not pdu.outlier:
# We only backfill backwards to the min depth.
- min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+ min_depth = yield self.handler.get_min_depth_for_context(
+ pdu.room_id
+ )
if min_depth and pdu.depth > min_depth:
- for pdu_id, origin in pdu.prev_pdus:
- exists = yield self._get_persisted_pdu(pdu_id, origin)
+ for event_id, hashes in pdu.prev_events:
+ exists = yield self._get_persisted_pdu(origin, event_id)
if not exists:
- logger.debug("Requesting pdu %s %s", pdu_id, origin)
+ logger.debug("Requesting pdu %s", event_id)
try:
yield self.get_pdu(
pdu.origin,
- pdu_id=pdu_id,
- pdu_origin=origin
+ event_id=event_id,
)
- logger.debug("Processed pdu %s %s", pdu_id, origin)
+ logger.debug("Processed pdu %s", event_id)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
-
- # Persist the Pdu, but don't mark it as processed yet.
- yield self.store.persist_event(pdu=pdu)
+ else:
+ # We need to get the state at this event, since we have reached
+ # a backward extremity edge.
+ state = yield self.get_state_for_context(
+ origin, pdu.room_id, pdu.event_id,
+ )
if not backfilled:
- ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+ ret = yield self.handler.on_receive_pdu(
+ pdu,
+ backfilled=backfilled,
+ state=state,
+ )
else:
ret = None
- yield self.pdu_actions.mark_as_processed(pdu)
+ # yield self.pdu_actions.mark_as_processed(pdu)
defer.returnValue(ret)
@@ -476,14 +564,6 @@ class ReplicationLayer(object):
return "<ReplicationLayer(%s)>" % self.server_name
-class ReplicationHandler(object):
- """This defines the methods that the :py:class:`.ReplicationLayer` will
- use to communicate with the rest of the home server.
- """
- def on_receive_pdu(self, pdu):
- raise NotImplementedError("on_receive_pdu")
-
-
class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
@@ -509,6 +589,9 @@ class _TransactionQueue(object):
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
+ # destination -> list of tuple(failure, deferred)
+ self.pending_failures_by_dest = {}
+
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@@ -562,6 +645,18 @@ class _TransactionQueue(object):
return deferred
@defer.inlineCallbacks
+ def enqueue_failure(self, failure, destination):
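+ """ Queues a failure to be sent to the given destination in the
+ next transaction.
+ """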
+ deferred = defer.Deferred()
+
+ self.pending_failures_by_dest.setdefault(
+ destination, []
+ ).append(
+ (failure, deferred)
+ )
+
+ yield deferred
+
+ @defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
if destination in self.pending_transactions:
@@ -570,8 +665,9 @@ class _TransactionQueue(object):
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
- if not pending_pdus and not pending_edus:
+ if not pending_pdus and not pending_edus and not pending_failures:
return
logger.debug("TX [%s] Attempting new transaction", destination)
@@ -581,7 +677,11 @@ class _TransactionQueue(object):
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
- deferreds = [x[1] for x in pending_pdus + pending_edus]
+ failures = [x[0].get_dict() for x in pending_failures]
+ deferreds = [
+ x[1]
+ for x in pending_pdus + pending_edus + pending_failures
+ ]
try:
self.pending_transactions[destination] = 1
@@ -589,12 +689,13 @@ class _TransactionQueue(object):
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
- origin_server_ts=self._clock.time_msec(),
+ origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
+ pdu_failures=failures,
)
self._next_txn_id += 1
@@ -614,7 +715,9 @@ class _TransactionQueue(object):
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
- p["age"] = now - int(p["age_ts"])
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
return data
code, response = yield self.transport_layer.send_transaction(
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index e7517cac4d..95c40c6c1b 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,7 +72,7 @@ class TransportLayer(object):
self.received_handler = None
@log_function
- def get_context_state(self, destination, context):
+ def get_context_state(self, destination, context, event_id=None):
""" Requests all state for a given context (i.e. room) from the
given server.
@@ -89,54 +89,62 @@ class TransportLayer(object):
subpath = "/state/%s/" % context
- return self._do_request_for_transaction(destination, subpath)
+ args = {}
+ if event_id:
+ args["event_id"] = event_id
+
+ return self._do_request_for_transaction(
+ destination, subpath, args=args
+ )
@log_function
- def get_pdu(self, destination, pdu_origin, pdu_id):
+ def get_event(self, destination, event_id):
""" Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
- pdu_origin (str): The home server which created the PDU.
- pdu_id (str): The id of the PDU being requested.
+ event_id (str): The id of the event being requested.
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
- logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
- destination, pdu_origin, pdu_id)
+ logger.debug("get_pdu dest=%s, event_id=%s",
+ destination, event_id)
- subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
+ subpath = "/event/%s/" % (event_id, )
return self._do_request_for_transaction(destination, subpath)
@log_function
- def backfill(self, dest, context, pdu_tuples, limit):
+ def backfill(self, dest, context, event_tuples, limit):
""" Requests `limit` previous PDUs in a given context before list of
PDUs.
Args:
dest (str)
context (str)
- pdu_tuples (list)
+ event_tuples (list)
limit (int)
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
- "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
- dest, context, repr(pdu_tuples), str(limit)
+ "backfill dest=%s, context=%s, event_tuples=%s, limit=%s",
+ dest, context, repr(event_tuples), str(limit)
)
- if not pdu_tuples:
+ if not event_tuples:
+ # TODO: raise?
return
- subpath = "/backfill/%s/" % context
+ subpath = "/backfill/%s/" % (context,)
- args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
- args["limit"] = limit
+ args = {
+ "v": event_tuples,
+ "limit": limit,
+ }
return self._do_request_for_transaction(
dest,
@@ -198,6 +206,72 @@ class TransportLayer(object):
defer.returnValue(response)
@defer.inlineCallbacks
+ @log_function
+ def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
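+ """ Asks the given server for a join event template for the given
+ user and room via GET /make_join/<context>/<user_id>.
+ """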
+ path = PREFIX + "/make_join/%s/%s" % (context, user_id,)
+
+ response = yield self.client.get_json(
+ destination=destination,
+ path=path,
+ retry_on_dns_fail=retry_on_dns_fail,
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_join(self, destination, context, event_id, content):
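+ """ Sends a join event via PUT /send_join/<context>/<event_id> and
+ returns the decoded JSON body of the response.
+ """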
+ path = PREFIX + "/send_join/%s/%s" % (
+ context,
+ event_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_join", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_invite(self, destination, context, event_id, content):
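+ """ Sends an invite event via PUT /invite/<context>/<event_id> and
+ returns the decoded JSON body of the response.
+ """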
+ path = PREFIX + "/invite/%s/%s" % (
+ context,
+ event_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_invite", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_event_auth(self, destination, context, event_id):
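+ """ Requests the auth chain for an event via
+ GET /event_auth/<context>/<event_id>.
+ """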
+ path = PREFIX + "/event_auth/%s/%s" % (
+ context,
+ event_id,
+ )
+
+ response = yield self.client.get_json(
+ destination=destination,
+ path=path,
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
"method": request.method,
@@ -210,7 +284,7 @@ class TransportLayer(object):
origin = None
if request.method == "PUT":
- #TODO: Handle other method types? other content types?
+ # TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
content = json.loads(content_bytes)
@@ -222,11 +296,13 @@ class TransportLayer(object):
try:
params = auth.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
+
def strip_quotes(value):
if value.startswith("\""):
return value[1:-1]
else:
return value
+
origin = strip_quotes(param_dict["origin"])
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
@@ -247,7 +323,7 @@ class TransportLayer(object):
if auth.startswith("X-Matrix"):
(origin, key, sig) = parse_auth_header(auth)
json_request["origin"] = origin
- json_request["signatures"].setdefault(origin,{})[key] = sig
+ json_request["signatures"].setdefault(origin, {})[key] = sig
if not json_request["signatures"]:
raise SynapseError(
@@ -313,10 +389,10 @@ class TransportLayer(object):
# data_id pair.
self.server.register_path(
"GET",
- re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
+ re.compile("^" + PREFIX + "/event/([^/]*)/$"),
self._with_authentication(
- lambda origin, content, query, pdu_origin, pdu_id:
- handler.on_pdu_request(pdu_origin, pdu_id)
+ lambda origin, content, query, event_id:
+ handler.on_pdu_request(origin, event_id)
)
)
@@ -326,7 +402,11 @@ class TransportLayer(object):
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
- handler.on_context_state_request(context)
+ handler.on_context_state_request(
+ origin,
+ context,
+ query.get("event_id", [None])[0],
+ )
)
)
@@ -336,28 +416,63 @@ class TransportLayer(object):
self._with_authentication(
lambda origin, content, query, context:
self._on_backfill_request(
- context, query["v"], query["limit"]
+ origin, context, query["v"], query["limit"]
)
)
)
+ # This is when we receive a server-server Query
self.server.register_path(
"GET",
- re.compile("^" + PREFIX + "/context/([^/]*)/$"),
+ re.compile("^" + PREFIX + "/query/([^/]*)$"),
self._with_authentication(
- lambda origin, content, query, context:
- handler.on_context_pdus_request(context)
+ lambda origin, content, query, query_type:
+ handler.on_query_request(
+ query_type, {k: v[0] for k, v in query.items()}
+ )
)
)
- # This is when we receive a server-server Query
self.server.register_path(
"GET",
- re.compile("^" + PREFIX + "/query/([^/]*)$"),
+ re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
self._with_authentication(
- lambda origin, content, query, query_type:
- handler.on_query_request(
- query_type, {k: v[0] for k, v in query.items()}
+ lambda origin, content, query, context, user_id:
+ self._on_make_join_request(
+ origin, content, query, context, user_id
+ )
+ )
+ )
+
+ self.server.register_path(
+ "GET",
+ re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, event_id:
+ handler.on_event_auth(
+ origin, context, event_id,
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, event_id:
+ self._on_send_join_request(
+ origin, content, query,
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, event_id:
+ self._on_invite_request(
+ origin, content, query,
)
)
)
@@ -402,7 +517,8 @@ class TransportLayer(object):
return
try:
- code, response = yield self.received_handler.on_incoming_transaction(
+ handler = self.received_handler
+ code, response = yield handler.on_incoming_transaction(
transaction_data
)
except:
@@ -440,7 +556,7 @@ class TransportLayer(object):
defer.returnValue(data)
@log_function
- def _on_backfill_request(self, context, v_list, limits):
+ def _on_backfill_request(self, origin, context, v_list, limits):
if not limits:
return defer.succeed(
(400, {"error": "Did not include limit param"})
@@ -448,124 +564,34 @@ class TransportLayer(object):
limit = int(limits[-1])
- versions = [v.split(",", 1) for v in v_list]
+ versions = v_list
return self.request_handler.on_backfill_request(
- context, versions, limit)
-
-
-class TransportReceivedHandler(object):
- """ Callbacks used when we receive a transaction
- """
- def on_incoming_transaction(self, transaction):
- """ Called on PUT /send/<transaction_id>, or on response to a request
- that we sent (e.g. a backfill request)
-
- Args:
- transaction (synapse.transaction.Transaction): The transaction that
- was sent to us.
-
- Returns:
- twisted.internet.defer.Deferred: A deferred that gets fired when
- the transaction has finished being processed.
-
- The result should be a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
-
-class TransportRequestHandler(object):
- """ Handlers used when someone want's data from us
- """
- def on_pull_request(self, versions):
- """ Called on GET /pull/?v=...
-
- This is hit when a remote home server wants to get all data
- after a given transaction. Mainly used when a home server comes back
- online and wants to get everything it has missed.
-
- Args:
- versions (list): A list of transaction_ids that should be used to
- determine what PDUs the remote side have not yet seen.
-
- Returns:
- Deferred: Resultsin a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
- def on_pdu_request(self, pdu_origin, pdu_id):
- """ Called on GET /pdu/<pdu_origin>/<pdu_id>/
-
- Someone wants a particular PDU. This PDU may or may not have originated
- from us.
-
- Args:
- pdu_origin (str)
- pdu_id (str)
-
- Returns:
- Deferred: Resultsin a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
- def on_context_state_request(self, context):
- """ Called on GET /state/<context>/
-
- Gets hit when someone wants all the *current* state for a given
- contexts.
-
- Args:
- context (str): The name of the context that we're interested in.
-
- Returns:
- twisted.internet.defer.Deferred: A deferred that gets fired when
- the transaction has finished being processed.
-
- The result should be a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
- def on_backfill_request(self, context, versions, limit):
- """ Called on GET /backfill/<context>/?v=...&limit=...
+ origin, context, versions, limit
+ )
- Gets hit when we want to backfill backwards on a given context from
- the given point.
+ @defer.inlineCallbacks
+ @log_function
+ def _on_make_join_request(self, origin, content, query, context, user_id):
+ content = yield self.request_handler.on_make_join_request(
+ context, user_id,
+ )
+ defer.returnValue((200, content))
- Args:
- context (str): The context to backfill
- versions (list): A list of 2-tuples representing where to backfill
- from, in the form `(pdu_id, origin)`
- limit (int): How many pdus to return.
+ @defer.inlineCallbacks
+ @log_function
+ def _on_send_join_request(self, origin, content, query):
+ content = yield self.request_handler.on_send_join_request(
+ origin, content,
+ )
- Returns:
- Deferred: Results in a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
+ defer.returnValue((200, content))
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
+ @defer.inlineCallbacks
+ @log_function
+ def _on_invite_request(self, origin, content, query):
+ content = yield self.request_handler.on_invite_request(
+ origin, content,
+ )
- def on_query_request(self):
- """ Called on a GET /query/<query_type> request. """
+ defer.returnValue((200, content))
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index b2fb964180..70412439cd 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -20,8 +20,6 @@ server protocol.
from synapse.util.jsonobject import JsonEncodedObject
import logging
-import json
-import copy
logger = logging.getLogger(__name__)
@@ -33,13 +31,13 @@ class Pdu(JsonEncodedObject):
A Pdu can be classified as "state". For a given context, we can efficiently
retrieve all state pdu's that haven't been clobbered. Clobbering is done
- via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
+ via a unique constraint on the tuple (context, type, state_key). A pdu
is a state pdu if `is_state` is True.
Example pdu::
{
- "pdu_id": "78c",
+ "event_id": "$78c:example.com",
"origin_server_ts": 1404835423000,
"origin": "bar",
"prev_ids": [
@@ -52,24 +50,21 @@ class Pdu(JsonEncodedObject):
"""
valid_keys = [
- "pdu_id",
- "context",
+ "event_id",
+ "room_id",
"origin",
"origin_server_ts",
- "pdu_type",
+ "type",
"destinations",
- "transaction_id",
- "prev_pdus",
+ "prev_events",
"depth",
"content",
- "outlier",
- "is_state", # Below this are keys valid only for State Pdus.
- "state_key",
- "power_level",
- "prev_state_id",
- "prev_state_origin",
- "required_power_level",
+ "hashes",
"user_id",
+ "auth_events",
+ "signatures", # Below this are keys valid only for State Pdus.
+ "state_key",
+ "prev_state",
]
internal_keys = [
@@ -79,61 +74,28 @@ class Pdu(JsonEncodedObject):
]
required_keys = [
- "pdu_id",
- "context",
+ "event_id",
+ "room_id",
"origin",
"origin_server_ts",
- "pdu_type",
+ "type",
"content",
]
# TODO: We need to make this properly load content rather than
# just leaving it as a dict. (OR DO WE?!)
- def __init__(self, destinations=[], is_state=False, prev_pdus=[],
- outlier=False, **kwargs):
- if is_state:
- for required_key in ["state_key"]:
- if required_key not in kwargs:
- raise RuntimeError("Key %s is required" % required_key)
-
+ def __init__(self, destinations=[], prev_events=[],
+ outlier=False, hashes={}, signatures={}, **kwargs):
super(Pdu, self).__init__(
destinations=destinations,
- is_state=is_state,
- prev_pdus=prev_pdus,
+ prev_events=prev_events,
outlier=outlier,
+ hashes=hashes,
+ signatures=signatures,
**kwargs
)
- @classmethod
- def from_pdu_tuple(cls, pdu_tuple):
- """ Converts a PduTuple to a Pdu
-
- Args:
- pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
- convert
-
- Returns:
- Pdu
- """
- if pdu_tuple:
- d = copy.copy(pdu_tuple.pdu_entry._asdict())
- d["origin_server_ts"] = d.pop("ts")
-
- d["content"] = json.loads(d["content_json"])
- del d["content_json"]
-
- args = {f: d[f] for f in cls.valid_keys if f in d}
- if "unrecognized_keys" in d and d["unrecognized_keys"]:
- args.update(json.loads(d["unrecognized_keys"]))
-
- return Pdu(
- prev_pdus=pdu_tuple.prev_pdu_list,
- **args
- )
- else:
- return None
-
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
@@ -160,11 +122,10 @@ class Edu(JsonEncodedObject):
"edu_type",
]
-# TODO: SYN-103: Remove "origin" and "destination" keys.
-# internal_keys = [
-# "origin",
-# "destination",
-# ]
+ internal_keys = [
+ "origin",
+ "destination",
+ ]
class Transaction(JsonEncodedObject):
@@ -193,6 +154,7 @@ class Transaction(JsonEncodedObject):
"edus",
"transaction_id",
"destination",
+ "pdu_failures",
]
internal_keys = [
@@ -229,7 +191,9 @@ class Transaction(JsonEncodedObject):
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:
- raise KeyError("Require 'origin_server_ts' to construct a Transaction")
+ raise KeyError(
+ "Require 'origin_server_ts' to construct a Transaction"
+ )
if "transaction_id" not in kwargs:
raise KeyError(
"Require 'transaction_id' to construct a Transaction"
@@ -241,6 +205,3 @@ class Transaction(JsonEncodedObject):
kwargs["pdus"] = [p.get_dict() for p in pdus]
return Transaction(**kwargs)
-
-
-