diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 9f8aadccca..8abf67b1b5 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -256,31 +256,35 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def get_state_for_context(self, destination, context, event_id=None):
+ def get_state_for_context(self, destination, context, event_id):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.
Args:
destination (str): The remote homeserver to query for the state.
context (str): The context we're interested in.
+ event_id (str): The id of the event we want the state at.
Returns:
Deferred: Results in a list of PDUs.
"""
- transaction_data = yield self.transport_layer.get_context_state(
+ result = yield self.transport_layer.get_context_state(
destination,
context,
event_id=event_id,
)
- transaction = Transaction(**transaction_data)
pdus = [
+ self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+ ]
+
+ auth_chain = [
self.event_from_pdu_json(p, outlier=True)
- for p in transaction.pdus
+ for p in result.get("auth_chain", [])
]
- defer.returnValue(pdus)
+ defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks
@log_function
@@ -383,10 +387,16 @@ class ReplicationLayer(object):
context,
event_id,
)
+ auth_chain = yield self.store.get_auth_chain(
+ [pdu.event_id for pdu in pdus]
+ )
else:
raise NotImplementedError("Specify an event")
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+ defer.returnValue((200, {
+ "pdus": [pdu.get_pdu_json() for pdu in pdus],
+ "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+ }))
@defer.inlineCallbacks
@log_function
@@ -562,8 +572,8 @@ class ReplicationLayer(object):
already_seen = (
existing and (
- not existing.internal_metadata.outlier
- or pdu.internal_metadata.outlier
+ not existing.internal_metadata.is_outlier()
+ or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
@@ -573,6 +583,8 @@ class ReplicationLayer(object):
state = None
+ auth_chain = []
+
# We need to make sure we have all the auth events.
# for e_id, _ in pdu.auth_events:
# exists = yield self._get_persisted_pdu(
@@ -604,7 +616,7 @@ class ReplicationLayer(object):
# )
# Get missing pdus if necessary.
- if not pdu.internal_metadata.outlier:
+ if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
@@ -645,7 +657,7 @@ class ReplicationLayer(object):
"_handle_new_pdu getting state for %s",
pdu.room_id
)
- state = yield self.get_state_for_context(
+ state, auth_chain = yield self.get_state_for_context(
origin, pdu.room_id, pdu.event_id,
)
@@ -655,6 +667,7 @@ class ReplicationLayer(object):
pdu,
backfilled=backfilled,
state=state,
+ auth_chain=auth_chain,
)
else:
ret = None
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4aec3563ac..e23c5c2195 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -91,11 +91,12 @@ class FederationHandler(BaseHandler):
yield run_on_reactor()
- yield self.replication_layer.send_pdu(event, destinations)
+ self.replication_layer.send_pdu(event, destinations)
@log_function
@defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, backfilled, state=None):
+ def on_receive_pdu(self, origin, pdu, backfilled, state=None,
+ auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler.
"""
@@ -150,40 +151,41 @@ class FederationHandler(BaseHandler):
if not is_in_room and not event.internal_metadata.outlier:
logger.debug("Got event for room we're not in.")
- replication_layer = self.replication_layer
- auth_chain = yield replication_layer.get_event_auth(
- origin,
- context=event.room_id,
- event_id=event.event_id,
- )
+ replication = self.replication_layer
+
+ if not state:
+ state, auth_chain = yield replication.get_state_for_context(
+ origin, context=event.room_id, event_id=event.event_id,
+ )
+
+ if not auth_chain:
+ auth_chain = yield replication.get_event_auth(
+ origin,
+ context=event.room_id,
+ event_id=event.event_id,
+ )
for e in auth_chain:
e.internal_metadata.outlier = True
try:
- yield self._handle_new_event(e, fetch_missing=False)
+ yield self._handle_new_event(e, fetch_auth_from=origin)
except:
logger.exception(
- "Failed to parse auth event %s",
+ "Failed to handle auth event %s",
e.event_id,
)
- if not state:
- state = yield replication_layer.get_state_for_context(
- origin,
- context=event.room_id,
- event_id=event.event_id,
- )
-
current_state = state
if state:
for e in state:
+ logging.info("A :) %r", e)
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(e)
except:
logger.exception(
- "Failed to parse state event %s",
+ "Failed to handle state event %s",
e.event_id,
)
@@ -288,7 +290,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def on_event_auth(self, event_id):
- auth = yield self.store.get_auth_chain(event_id)
+ auth = yield self.store.get_auth_chain([event_id])
for event in auth:
event.signatures.update(
@@ -391,10 +393,10 @@ class FederationHandler(BaseHandler):
for e in auth_chain:
e.internal_metadata.outlier = True
try:
- yield self._handle_new_event(e, fetch_missing=False)
+ yield self._handle_new_event(e)
except:
logger.exception(
- "Failed to parse auth event %s",
+ "Failed to handle auth event %s",
e.event_id,
)
@@ -403,12 +405,11 @@ class FederationHandler(BaseHandler):
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(
- e,
- fetch_missing=True
+ e, fetch_auth_from=target_host
)
except:
logger.exception(
- "Failed to parse state event %s",
+ "Failed to handle state event %s",
e.event_id,
)
@@ -526,9 +527,12 @@ class FederationHandler(BaseHandler):
event.signatures,
)
- yield self.replication_layer.send_pdu(new_pdu, destinations)
+ self.replication_layer.send_pdu(new_pdu, destinations)
- auth_chain = yield self.store.get_auth_chain(event.event_id)
+ state_ids = [e.event_id for e in context.current_state.values()]
+ auth_chain = yield self.store.get_auth_chain(set(
+ [event.event_id] + state_ids
+ ))
defer.returnValue({
"state": context.current_state.values(),
@@ -678,7 +682,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False,
- current_state=None, fetch_missing=True):
+ current_state=None, fetch_auth_from=None):
logger.debug(
"_handle_new_event: Before annotate: %s, sigs: %s",
@@ -699,11 +703,20 @@ class FederationHandler(BaseHandler):
known_ids = set(
[s.event_id for s in context.auth_events.values()]
)
+
for e_id, _ in event.auth_events:
if e_id not in known_ids:
- e = yield self.store.get_event(
- e_id, allow_none=True,
- )
+ e = yield self.store.get_event(e_id, allow_none=True)
+
+ if not e and fetch_auth_from is not None:
+ # Grab the auth_chain over federation if we are missing
+ # auth events.
+ auth_chain = yield self.replication_layer.get_event_auth(
+ fetch_auth_from, event.event_id, event.room_id
+ )
+ for auth_event in auth_chain:
+ yield self._handle_new_event(auth_event)
+ e = yield self.store.get_event(e_id, allow_none=True)
if not e:
# TODO: Do some conflict res to make sure that we're
@@ -713,7 +726,7 @@ class FederationHandler(BaseHandler):
event.event_id, e_id, known_ids,
)
# FIXME: How does raising AuthError work with federation?
- raise AuthError(403, "Auth events are stale")
+ raise AuthError(403, "Cannot find auth event")
context.auth_events[(e.type, e.state_key)] = e
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 60c2d67425..e6bb665932 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -143,9 +143,7 @@ class DataStore(RoomMemberStore, RoomStore,
elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event)
- outlier = False
- if hasattr(event.internal_metadata, "outlier"):
- outlier = event.internal_metadata.outlier
+ outlier = event.internal_metadata.is_outlier()
event_dict = {
k: v
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6dc857c4aa..e0d97f440b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -488,11 +488,13 @@ class SQLBaseStore(object):
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
- ev.unsigned["prev_content"] = self._get_event_txn(
+ prev = self._get_event_txn(
txn,
ev.unsigned["replaces_state"],
get_prev_content=False,
- ).get_dict()["content"]
+ )
+ if prev:
+ ev.unsigned["prev_content"] = prev.get_dict()["content"]
return ev
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index ced066f407..fb2eb21713 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -32,39 +32,33 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively.
"""
- def get_auth_chain(self, event_id):
+ def get_auth_chain(self, event_ids):
return self.runInteraction(
"get_auth_chain",
self._get_auth_chain_txn,
- event_id
+ event_ids
)
- def _get_auth_chain_txn(self, txn, event_id):
- results = self._get_auth_chain_ids_txn(txn, event_id)
+ def _get_auth_chain_txn(self, txn, event_ids):
+ results = self._get_auth_chain_ids_txn(txn, event_ids)
- sql = "SELECT * FROM events WHERE event_id = ?"
- rows = []
- for ev_id in results:
- c = txn.execute(sql, (ev_id,))
- rows.extend(self.cursor_to_dict(c))
+ return self._get_events_txn(txn, results)
- return self._parse_events_txn(txn, rows)
-
- def get_auth_chain_ids(self, event_id):
+ def get_auth_chain_ids(self, event_ids):
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
- event_id
+ event_ids
)
- def _get_auth_chain_ids_txn(self, txn, event_id):
+ def _get_auth_chain_ids_txn(self, txn, event_ids):
results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE %s"
)
- front = set([event_id])
+ front = set(event_ids)
while front:
sql = base_sql % (
" OR ".join(["event_id=?"] * len(front)),
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 701ccdb781..6925ac96b6 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -115,10 +115,10 @@ class Signal(object):
failure.value,
failure.getTracebackObject()))
if not self.suppress_failures:
- raise failure
+ failure.raiseException()
deferreds.append(d.addErrback(eb))
-
- result = yield defer.DeferredList(
- deferreds, fireOnOneErrback=not self.suppress_failures
- )
- defer.returnValue(result)
+ results = []
+ for deferred in deferreds:
+ result = yield deferred
+ results.append(result)
+ defer.returnValue(results)
|