diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bcdcc90a18..40836e0c51 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,19 +17,16 @@
from ._base import BaseHandler
-from synapse.events.utils import prune_event
from synapse.api.errors import (
- AuthError, FederationError, SynapseError, StoreError,
+ AuthError, FederationError, StoreError,
)
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import (
- compute_event_signature, check_event_content_hash,
- add_hashes_and_signatures,
+ compute_event_signature, add_hashes_and_signatures,
)
from synapse.types import UserID
-from syutil.jsonutil import encode_canonical_json
from twisted.internet import defer
@@ -113,33 +110,6 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event.event_id)
- redacted_event = prune_event(event)
-
- redacted_pdu_json = redacted_event.get_pdu_json()
- try:
- yield self.keyring.verify_json_for_server(
- event.origin, redacted_pdu_json
- )
- except SynapseError as e:
- logger.warn(
- "Signature check failed for %s redacted to %s",
- encode_canonical_json(pdu.get_pdu_json()),
- encode_canonical_json(redacted_pdu_json),
- )
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
- )
-
- if not check_event_content_hash(event):
- logger.warn(
- "Event content has been tampered, redacting %s, %s",
- event.event_id, encode_canonical_json(event.get_dict())
- )
- event = redacted_event
-
logger.debug("Event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently
@@ -149,41 +119,20 @@ class FederationHandler(BaseHandler):
event.room_id,
self.server_name
)
- if not is_in_room and not event.internal_metadata.outlier:
+ if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.")
-
- replication = self.replication_layer
-
- if not state:
- state, auth_chain = yield replication.get_state_for_room(
- origin, context=event.room_id, event_id=event.event_id,
- )
-
- if not auth_chain:
- auth_chain = yield replication.get_event_auth(
- origin,
- context=event.room_id,
- event_id=event.event_id,
- )
-
- for e in auth_chain:
- e.internal_metadata.outlier = True
- try:
- yield self._handle_new_event(e, fetch_auth_from=origin)
- except:
- logger.exception(
- "Failed to handle auth event %s",
- e.event_id,
- )
-
current_state = state
- if state:
+ if state and auth_chain is not None:
for e in state:
- logging.info("A :) %r", e)
e.internal_metadata.outlier = True
try:
- yield self._handle_new_event(e)
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ yield self._handle_new_event(origin, e, auth_events=auth)
except:
logger.exception(
"Failed to handle state event %s",
@@ -192,6 +141,7 @@ class FederationHandler(BaseHandler):
try:
yield self._handle_new_event(
+ origin,
event,
state=state,
backfilled=backfilled,
@@ -394,7 +344,14 @@ class FederationHandler(BaseHandler):
for e in auth_chain:
e.internal_metadata.outlier = True
try:
- yield self._handle_new_event(e)
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ yield self._handle_new_event(
+ target_host, e, auth_events=auth
+ )
except:
logger.exception(
"Failed to handle auth event %s",
@@ -405,8 +362,13 @@ class FederationHandler(BaseHandler):
# FIXME: Auth these.
e.internal_metadata.outlier = True
try:
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
yield self._handle_new_event(
- e, fetch_auth_from=target_host
+ target_host, e, auth_events=auth
)
except:
logger.exception(
@@ -415,6 +377,7 @@ class FederationHandler(BaseHandler):
)
yield self._handle_new_event(
+ target_host,
new_event,
state=state,
current_state=state,
@@ -481,7 +444,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
- context = yield self._handle_new_event(event)
+ context = yield self._handle_new_event(origin, event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -682,11 +645,12 @@ class FederationHandler(BaseHandler):
waiters.pop().callback(None)
@defer.inlineCallbacks
- def _handle_new_event(self, event, state=None, backfilled=False,
- current_state=None, fetch_auth_from=None):
+ @log_function
+ def _handle_new_event(self, origin, event, state=None, backfilled=False,
+ current_state=None, auth_events=None):
logger.debug(
- "_handle_new_event: Before annotate: %s, sigs: %s",
+ "_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures,
)
@@ -694,65 +658,44 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
+ if not auth_events:
+ auth_events = context.auth_events
+
logger.debug(
- "_handle_new_event: Before auth fetch: %s, sigs: %s",
- event.event_id, event.signatures,
+ "_handle_new_event: %s, auth_events: %s",
+ event.event_id, auth_events,
)
is_new_state = not event.internal_metadata.is_outlier()
- known_ids = set(
- [s.event_id for s in context.auth_events.values()]
- )
-
- for e_id, _ in event.auth_events:
- if e_id not in known_ids:
- e = yield self.store.get_event(e_id, allow_none=True)
-
- if not e and fetch_auth_from is not None:
- # Grab the auth_chain over federation if we are missing
- # auth events.
- auth_chain = yield self.replication_layer.get_event_auth(
- fetch_auth_from, event.event_id, event.room_id
- )
- for auth_event in auth_chain:
- yield self._handle_new_event(auth_event)
- e = yield self.store.get_event(e_id, allow_none=True)
-
- if not e:
- # TODO: Do some conflict res to make sure that we're
- # not the ones who are wrong.
- logger.info(
- "Rejecting %s as %s not in db or %s",
- event.event_id, e_id, known_ids,
- )
- # FIXME: How does raising AuthError work with federation?
- raise AuthError(403, "Cannot find auth event")
-
- context.auth_events[(e.type, e.state_key)] = e
-
- logger.debug(
- "_handle_new_event: Before hack: %s, sigs: %s",
- event.event_id, event.signatures,
- )
-
+ # This is a hack to fix some old rooms where the initial join event
+ # didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
- context.auth_events[(c.type, c.state_key)] = c
+ auth_events[(c.type, c.state_key)] = c
- logger.debug(
- "_handle_new_event: Before auth check: %s, sigs: %s",
- event.event_id, event.signatures,
- )
+ try:
+ yield self.do_auth(
+ origin, event, context, auth_events=auth_events
+ )
+ except AuthError as e:
+ logger.warn(
+ "Rejecting %s because %s",
+ event.event_id, e.msg
+ )
- self.auth.check(event, auth_events=context.auth_events)
+ context.rejected = RejectedReason.AUTH_ERROR
- logger.debug(
- "_handle_new_event: Before persist_event: %s, sigs: %s",
- event.event_id, event.signatures,
- )
+ yield self.store.persist_event(
+ event,
+ context=context,
+ backfilled=backfilled,
+ is_new_state=False,
+ current_state=current_state,
+ )
+ raise
yield self.store.persist_event(
event,
@@ -762,9 +705,257 @@ class FederationHandler(BaseHandler):
current_state=current_state,
)
- logger.debug(
- "_handle_new_event: After persist_event: %s, sigs: %s",
- event.event_id, event.signatures,
+ defer.returnValue(context)
+
+ @defer.inlineCallbacks
+ def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
+ missing):
+ # Just go through and process each event in `remote_auth_chain`. We
+ # don't want to fall into the trap of `missing` being wrong.
+ for e in remote_auth_chain:
+ try:
+ yield self._handle_new_event(origin, e)
+ except AuthError:
+ pass
+
+ # Now get the current auth_chain for the event.
+ local_auth_chain = yield self.store.get_auth_chain([event_id])
+
+ # TODO: Check if we would now reject event_id. If so we need to tell
+ # everyone.
+
+ ret = yield self.construct_auth_difference(
+ local_auth_chain, remote_auth_chain
)
- defer.returnValue(context)
+ logger.debug("on_query_auth reutrning: %s", ret)
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ @log_function
+ def do_auth(self, origin, event, context, auth_events):
+ # Check if we have all the auth events.
+ res = yield self.store.have_events(
+ [e_id for e_id, _ in event.auth_events]
+ )
+
+ event_auth_events = set(e_id for e_id, _ in event.auth_events)
+ seen_events = set(res.keys())
+
+ missing_auth = event_auth_events - seen_events
+
+ if missing_auth:
+ logger.debug("Missing auth: %s", missing_auth)
+ # If we don't have all the auth events, we need to get them.
+ remote_auth_chain = yield self.replication_layer.get_event_auth(
+ origin, event.room_id, event.event_id
+ )
+
+ for e in remote_auth_chain:
+ try:
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in remote_auth_chain
+ if e.event_id in auth_ids
+ }
+ e.internal_metadata.outlier = True
+ yield self._handle_new_event(
+ origin, e, auth_events=auth
+ )
+ auth_events[(e.type, e.state_key)] = e
+ except AuthError:
+ pass
+
+ # FIXME: Assumes we have and stored all the state for all the
+ # prev_events
+ current_state = set(e.event_id for e in auth_events.values())
+ different_auth = event_auth_events - current_state
+
+ if different_auth and not event.internal_metadata.is_outlier():
+ # Do auth conflict res.
+ logger.debug("Different auth: %s", different_auth)
+
+ # 1. Get what we think is the auth chain.
+ auth_ids = self.auth.compute_auth_events(event, context)
+ local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+
+ # 2. Get remote difference.
+ result = yield self.replication_layer.query_auth(
+ origin,
+ event.room_id,
+ event.event_id,
+ local_auth_chain,
+ )
+
+ # 3. Process any remote auth chain events we haven't seen.
+ for missing_id in result.get("missing", []):
+ try:
+ for e in result["auth_chain"]:
+ if e.event_id == missing_id:
+ ev = e
+ break
+
+ auth_ids = [e_id for e_id, _ in ev.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in result["auth_chain"]
+ if e.event_id in auth_ids
+ }
+ ev.internal_metadata.outlier = True
+ yield self._handle_new_event(
+ origin, ev, auth_events=auth
+ )
+
+ if ev.event_id in event_auth_events:
+ auth_events[(ev.type, ev.state_key)] = ev
+ except AuthError:
+ pass
+
+ # 4. Look at rejects and their proofs.
+ # TODO.
+
+ context.current_state.update(auth_events)
+ context.state_group = None
+
+ try:
+ self.auth.check(event, auth_events=auth_events)
+ except AuthError:
+ raise
+
+ @defer.inlineCallbacks
+ def construct_auth_difference(self, local_auth, remote_auth):
+ """ Given a local and remote auth chain, find the differences. This
+ assumes that we have already processed all events in remote_auth
+
+ Params:
+ local_auth (list)
+ remote_auth (list)
+
+ Returns:
+ dict
+ """
+
+ logger.debug("construct_auth_difference Start!")
+
+ # TODO: Make sure we are OK with local_auth or remote_auth having more
+ # auth events in them than strictly necessary.
+
+ def sort_fun(ev):
+ return ev.depth, ev.event_id
+
+ logger.debug("construct_auth_difference after sort_fun!")
+
+ # We find the differences by starting at the "bottom" of each list
+ # and iterating up on both lists. The lists are ordered by depth and
+ # then event_id, we iterate up both lists until we find the event ids
+ # don't match. Then we look at depth/event_id to see which side is
+ # missing that event, and iterate only up that list. Repeat.
+
+ remote_list = list(remote_auth)
+ remote_list.sort(key=sort_fun)
+
+ local_list = list(local_auth)
+ local_list.sort(key=sort_fun)
+
+ local_iter = iter(local_list)
+ remote_iter = iter(remote_list)
+
+ logger.debug("construct_auth_difference before get_next!")
+
+ def get_next(it, opt=None):
+ try:
+ return it.next()
+ except:
+ return opt
+
+ current_local = get_next(local_iter)
+ current_remote = get_next(remote_iter)
+
+ logger.debug("construct_auth_difference before while")
+
+ missing_remotes = []
+ missing_locals = []
+ while current_local or current_remote:
+ if current_remote is None:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+ continue
+
+ if current_local is None:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ if current_local.event_id == current_remote.event_id:
+ current_local = get_next(local_iter)
+ current_remote = get_next(remote_iter)
+ continue
+
+ if current_local.depth < current_remote.depth:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+ continue
+
+ if current_local.depth > current_remote.depth:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ # They have the same depth, so we fall back to the event_id order
+ if current_local.event_id < current_remote.event_id:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+
+ if current_local.event_id > current_remote.event_id:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ logger.debug("construct_auth_difference after while")
+
+ # missing locals should be sent to the server
+ # We should find why we are missing remotes, as they will have been
+ # rejected.
+
+ # Remove events from missing_remotes if they are referencing a missing
+ # remote. We only care about the "root" rejected ones.
+ missing_remote_ids = [e.event_id for e in missing_remotes]
+ base_remote_rejected = list(missing_remotes)
+ for e in missing_remotes:
+ for e_id, _ in e.auth_events:
+ if e_id in missing_remote_ids:
+ base_remote_rejected.remove(e)
+
+ reason_map = {}
+
+ for e in base_remote_rejected:
+ reason = yield self.store.get_rejection_reason(e.event_id)
+ if reason is None:
+ # FIXME: ERRR?!
+ logger.warn("Could not find reason for %s", e.event_id)
+ raise RuntimeError("")
+
+ reason_map[e.event_id] = reason
+
+ if reason == RejectedReason.AUTH_ERROR:
+ pass
+ elif reason == RejectedReason.REPLACED:
+ # TODO: Get proof
+ pass
+ elif reason == RejectedReason.NOT_ANCESTOR:
+ # TODO: Get proof.
+ pass
+
+ logger.debug("construct_auth_difference returning")
+
+ defer.returnValue({
+ "auth_chain": local_auth,
+ "rejects": {
+ e.event_id: {
+ "reason": reason_map[e.event_id],
+ "proof": None,
+ }
+ for e in base_remote_rejected
+ },
+ "missing": [e.event_id for e in missing_locals],
+ })
|