diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..abe02907b9 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,8 +64,7 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
@@ -240,7 +239,6 @@ class FederationHandler(BaseHandler):
return None
state = None
- auth_chain = []
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
@@ -346,7 +344,6 @@ class FederationHandler(BaseHandler):
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
- auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
@@ -370,38 +367,14 @@ class FederationHandler(BaseHandler):
p,
)
- room_version = yield self.store.get_room_version(room_id)
-
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- (
- remote_state,
- got_auth_chain,
- ) = yield self.federation_client.get_state_for_room(
- origin, room_id, p
- )
-
- # we want the state *after* p; get_state_for_room returns the
- # state *before* p.
- remote_event = yield self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True
+ (remote_state, _,) = yield self._get_state_for_room(
+ origin, room_id, p, include_event_in_state=True
)
- if remote_event is None:
- raise Exception(
- "Unable to get missing prev_event %s" % (p,)
- )
-
- if remote_event.is_state():
- remote_state.append(remote_event)
-
- # XXX hrm I'm not convinced that duplicate events will compare
- # for equality, so I'm not sure this does what the author
- # hoped.
- auth_chains.update(got_auth_chain)
-
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
@@ -410,7 +383,9 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
+ room_version = yield self.store.get_room_version(room_id)
state_map = yield resolve_events_with_store(
+ room_id,
room_version,
state_maps,
event_map,
@@ -430,7 +405,6 @@ class FederationHandler(BaseHandler):
event_map.update(evs)
state = [event_map[e] for e in six.itervalues(state_map)]
- auth_chain = list(auth_chains)
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "
@@ -446,9 +420,7 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
- yield self._process_received_pdu(
- origin, pdu, state=state, auth_chain=auth_chain
- )
+ yield self._process_received_pdu(origin, pdu, state=state)
@defer.inlineCallbacks
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
@@ -584,49 +556,149 @@ class FederationHandler(BaseHandler):
raise
@defer.inlineCallbacks
- def _process_received_pdu(self, origin, event, state, auth_chain):
- """ Called when we have a new pdu. We need to do auth checks and put it
- through the StateHandler.
+ @log_function
+ def _get_state_for_room(
+ self, destination, room_id, event_id, include_event_in_state=False
+ ):
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination (str): The remote homeserver to query for the state.
+ room_id (str): The id of the room we're interested in.
+ event_id (str): The id of the event we want the state at.
+ include_event_in_state (bool): if true, the event itself will be included
+ in the returned state event list.
+
+ Returns:
+ Deferred[Tuple[List[EventBase], List[EventBase]]]:
+ A list of events in the state, and a list of events in the auth chain
+ for the given event.
"""
- room_id = event.room_id
- event_id = event.event_id
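+ # /state_ids gives us just the event ids of the state and auth chain at the
+ # event; we fetch the full events below.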
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = yield self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
- logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+ desired_events = set(state_event_ids + auth_event_ids)
+
+ if include_event_in_state:
+ desired_events.add(event_id)
+
+ event_map = yield self._get_events_from_store_or_dest(
+ destination, room_id, desired_events
+ )
- event_ids = set()
- if state:
- event_ids |= {e.event_id for e in state}
- if auth_chain:
- event_ids |= {e.event_id for e in auth_chain}
+ failed_to_fetch = desired_events - event_map.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state/auth events for %s: %s",
+ room_id,
+ failed_to_fetch,
+ )
- seen_ids = yield self.store.have_seen_events(event_ids)
+ remote_state = [
+ event_map[e_id] for e_id in state_event_ids if e_id in event_map
+ ]
- if state and auth_chain is not None:
- # If we have any state or auth_chain given to us by the replication
- # layer, then we should handle them (if we haven't before.)
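+ # the state returned by /state_ids is the state *before* the event; if the
+ # event is itself a state event, add it, so that callers get the state
+ # *after* it.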
+ if include_event_in_state:
+ remote_event = event_map.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
- event_infos = []
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
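+ # sort the auth chain by depth, so that (roughly speaking) ancestors come
+ # before the events which depend on them.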
+ auth_chain.sort(key=lambda e: e.depth)
- for e in itertools.chain(auth_chain, state):
- if e.event_id in seen_ids:
- continue
- e.internal_metadata.outlier = True
- auth_ids = e.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- event_infos.append(_NewEventInfo(event=e, auth_events=auth))
- seen_ids.add(e.event_id)
+ return remote_state, auth_chain
- logger.info(
- "[%s %s] persisting newly-received auth/state events %s",
+ @defer.inlineCallbacks
+ def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+ """Fetch events from a remote destination, checking if we already have them.
+
+ Persists any events we don't already have as outliers.
+
+ If we fail to fetch any of the events, a warning will be logged, and the
+ event will be omitted from the result. Likewise, any events which turn out
+ not to be in the given room are omitted.
+
+ Args:
+ destination (str)
+ room_id (str)
+ event_ids (Iterable[str])
+
+ Returns:
+ Deferred[dict[str, EventBase]]: A deferred resolving to a map
+ from event_id to event
+ """
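+ # check the local database first: anything we are missing will be fetched
+ # from the remote server and persisted as an outlier.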
+ fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+
+ missing_events = set(event_ids) - fetched_events.keys()
+
+ if missing_events:
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
room_id,
- event_id,
- [e.event.event_id for e in event_infos],
)
- yield self._handle_new_events(origin, event_infos)
+
+ yield self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ (yield self.store.get_events(missing_events, allow_rejected=True))
+ )
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = list(
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ )
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned auth/state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+ del fetched_events[bad_event_id]
+
+ return fetched_events
+
+ @defer.inlineCallbacks
+ def _process_received_pdu(self, origin, event, state):
+ """ Called when we have a new pdu. We need to do auth checks and put it
+ through the StateHandler.
+
+ Args:
+ origin (str): server sending the event
+ event (EventBase): event to be persisted
+ state (list[EventBase]|None): Normally None, but if we are handling a gap
+ in the graph (ie, we are missing one or more prev_events), the resolved
+ state at the event.
+ """
+ room_id = event.room_id
+ event_id = event.event_id
+
+ logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
try:
context = yield self._handle_new_event(origin, event, state=state)
@@ -683,8 +755,6 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- room_version = yield self.store.get_room_version(room_id)
-
events = yield self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -713,6 +783,9 @@ class FederationHandler(BaseHandler):
event_ids = set(e.event_id for e in events)
+ # build a list of events whose prev_events weren't in the batch.
+ # (XXX: this will include events whose prev_events we already have; that doesn't
+ # sound right?)
edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
@@ -723,7 +796,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self.federation_client.get_state_for_room(
+ state, auth = yield self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
@@ -741,95 +814,11 @@ class FederationHandler(BaseHandler):
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
- missing_auth = required_auth - set(auth_events)
- failed_to_fetch = set()
-
- # Try and fetch any missing auth events from both DB and remote servers.
- # We repeatedly do this until we stop finding new auth events.
- while missing_auth - failed_to_fetch:
- logger.info("Missing auth for backfill: %r", missing_auth)
- ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
- auth_events.update(ret_events)
-
- required_auth.update(
- a_id for event in ret_events.values() for a_id in event.auth_event_ids()
- )
- missing_auth = required_auth - set(auth_events)
-
- if missing_auth - failed_to_fetch:
- logger.info(
- "Fetching missing auth for backfill: %r",
- missing_auth - failed_to_fetch,
- )
-
- results = yield make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.federation_client.get_pdu,
- [dest],
- event_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
- )
- for event_id in missing_auth - failed_to_fetch
- ],
- consumeErrors=True,
- )
- ).addErrback(unwrapFirstError)
- auth_events.update({a.event_id: a for a in results if a})
- required_auth.update(
- a_id
- for event in results
- if event
- for a_id in event.auth_event_ids()
- )
- missing_auth = required_auth - set(auth_events)
-
- failed_to_fetch = missing_auth - set(auth_events)
-
- seen_events = yield self.store.have_seen_events(
- set(auth_events.keys()) | set(state_events.keys())
- )
-
- # We now have a chunk of events plus associated state and auth chain to
- # persist. We do the persistence in two steps:
- # 1. Auth events and state get persisted as outliers, plus the
- # backward extremities get persisted (as non-outliers).
- # 2. The rest of the events in the chunk get persisted one by one, as
- # each one depends on the previous event for its state.
- #
- # The important thing is that events in the chunk get persisted as
- # non-outliers, including when those events are also in the state or
- # auth chain. Caution must therefore be taken to ensure that they are
- # not accidentally marked as outliers.
- # Step 1a: persist auth events that *don't* appear in the chunk
ev_infos = []
- for a in auth_events.values():
- # We only want to persist auth events as outliers that we haven't
- # seen and aren't about to persist as part of the backfilled chunk.
- if a.event_id in seen_events or a.event_id in event_map:
- continue
- a.internal_metadata.outlier = True
- ev_infos.append(
- _NewEventInfo(
- event=a,
- auth_events={
- (
- auth_events[a_id].type,
- auth_events[a_id].state_key,
- ): auth_events[a_id]
- for a_id in a.auth_event_ids()
- if a_id in auth_events
- },
- )
- )
-
- # Step 1b: persist the events in the chunk we fetched state for (i.e.
- # the backwards extremities) as non-outliers.
+ # Step 1: persist the events in the chunk we fetched state for (i.e.
+ # the backwards extremities), with custom auth events and state
for e_id in events_to_state:
# For paranoia we ensure that these events are marked as
# non-outliers
@@ -1071,6 +1060,57 @@ class FederationHandler(BaseHandler):
return False
+ @defer.inlineCallbacks
+ def _get_events_and_persist(
+ self, destination: str, room_id: str, events: Iterable[str]
+ ):
+ """Fetch the given events from a server, and persist them as outliers.
+
+ Logs a warning if we can't fetch any of the given events.
+ """
+
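+ # we need the room version in order to parse the events we fetch.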
+ room_version = yield self.store.get_room_version(room_id)
+
+ event_infos = []
+
+ async def get_event(event_id: str):
+ with nested_logging_context(event_id):
+ try:
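+ # fetch the event as an outlier, since we are going to persist it
+ # without calculating the state at it.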
+ event = await self.federation_client.get_pdu(
+ [destination], event_id, room_version, outlier=True,
+ )
+ if event is None:
+ logger.warning(
+ "Server %s didn't return event %s", destination, event_id,
+ )
+ return
+
+ # recursively fetch the auth events for this event
+ auth_events = await self._get_events_from_store_or_dest(
+ destination, room_id, event.auth_event_ids()
+ )
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = auth_events.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+
+ event_infos.append(_NewEventInfo(event, None, auth))
+
+ except Exception as e:
+ logger.warning(
+ "Error fetching missing state/auth event %s: %s %s",
+ event_id,
+ type(e),
+ e,
+ )
+
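+ # fetch the events from the destination, no more than five at a time.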
+ yield concurrently_execute(get_event, events, 5)
+
+ yield self._handle_new_events(
+ destination, event_infos,
+ )
+
def _sanity_check_event(self, ev):
"""
Do some early sanity checks of a received event
|