diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0e904f2da0..d4f9a792fc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,11 +19,13 @@
import itertools
import logging
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
from six.moves import http_client, zip
+import attr
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -45,6 +47,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
+from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.logging.context import (
@@ -60,9 +63,9 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
-from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.types import StateMap, UserID, get_domain_from_id
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
@@ -72,6 +75,23 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
+@attr.s
+class _NewEventInfo:
+ """Holds information about a received event, ready for passing to _handle_new_events
+
+ Attributes:
+ event: the received event
+
+ state: the state at that event
+
+ auth_events: the auth_event map for that event
+ """
+
+ event = attr.ib(type=EventBase)
+ state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
+ auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
+
+
def shortstr(iterable, maxitems=5):
"""If iterable has maxitems or fewer, return the stringification of a list
containing those items.
@@ -121,6 +141,7 @@ class FederationHandler(BaseHandler):
self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
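+        # the MessageHandler is used to schedule expiry of ephemeral events
+        # (see persist_events_and_notify below)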
+ self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
@@ -141,8 +162,9 @@ class FederationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
- @defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+ self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
+ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -152,17 +174,15 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
-
- Returns (Deferred): completes with None
"""
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+ logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.store.get_event(
+ existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@@ -206,7 +226,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -217,25 +237,24 @@ class FederationHandler(BaseHandler):
return None
state = None
- auth_chain = []
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
- if min_depth and pdu.depth < min_depth:
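+            # NB: a min_depth of 0 is valid but falsy, so we compare against
+            # None explicitly rather than relying on truthiness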
+ if min_depth is not None and pdu.depth < min_depth:
# This is so that we don't notify the user about this
# message, to work around the fact that some events will
# reference really really old events we really don't want to
# send to the clients.
pdu.internal_metadata.outlier = True
- elif min_depth and pdu.depth > min_depth:
+ elif min_depth is not None and pdu.depth > min_depth:
missing_prevs = prevs - seen
if sent_to_us_directly and missing_prevs:
# If we're missing stuff, ensure we only fetch stuff one
@@ -247,7 +266,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
@@ -255,13 +274,19 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
)
- yield self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+                    except Exception as e:
+                        raise Exception(
+                            "Error fetching missing prev_events for %s: %s"
+                            % (event_id, e)
+                        ) from e
# Update the set of things we've seen after trying to
# fetch the missing stuff
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -269,14 +294,6 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
- elif missing_prevs:
- logger.info(
- "[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id,
- event_id,
- len(missing_prevs),
- shortstr(missing_prevs),
- )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -321,18 +338,21 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
+ logger.info(
+ "Event %s is missing prev_events: calculating state for a "
+ "backwards extremity",
+ event_id,
+ )
+
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
- auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
- state_maps = list(
- ours.values()
- ) # type: list[dict[tuple[str, str], str]]
+ state_maps = list(ours.values()) # type: list[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@@ -341,44 +361,17 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "[%s %s] Requesting state at missing prev_event %s",
- room_id,
- event_id,
- p,
+                    "Requesting state at missing prev_event %s", p,
)
- room_version = yield self.store.get_room_version(room_id)
-
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- (
- remote_state,
- got_auth_chain,
- ) = yield self.federation_client.get_state_for_room(
- origin, room_id, p
- )
-
- # we want the state *after* p; get_state_for_room returns the
- # state *before* p.
- remote_event = yield self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True
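+                    # with include_event_in_state=True this gives us the state
+                    # *after* p: p itself is included in the returned state if
+                    # it is a state event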
+ (remote_state, _,) = await self._get_state_for_room(
+ origin, room_id, p, include_event_in_state=True
)
- if remote_event is None:
- raise Exception(
- "Unable to get missing prev_event %s" % (p,)
- )
-
- if remote_event.is_state():
- remote_state.append(remote_event)
-
- # XXX hrm I'm not convinced that duplicate events will compare
- # for equality, so I'm not sure this does what the author
- # hoped.
- auth_chains.update(got_auth_chain)
-
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
@@ -387,7 +380,9 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- state_map = yield resolve_events_with_store(
+ room_version = await self.store.get_room_version(room_id)
+ state_map = await resolve_events_with_store(
+ room_id,
room_version,
state_maps,
event_map,
@@ -399,15 +394,14 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
- evs = yield self.store.get_events(
+ evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
state = [event_map[e] for e in six.itervalues(state_map)]
- auth_chain = list(auth_chains)
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "
@@ -423,12 +417,9 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
- yield self._process_received_pdu(
- origin, pdu, state=state, auth_chain=auth_chain
- )
+ await self._process_received_pdu(origin, pdu, state=state)
- @defer.inlineCallbacks
- def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@@ -440,12 +431,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
- latest = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -509,7 +500,7 @@ class FederationHandler(BaseHandler):
        # All that said: Let's try increasing the timeout to 60s and see what happens.
try:
- missing_events = yield self.federation_client.get_missing_events(
+ missing_events = await self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@@ -548,7 +539,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+ await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
@@ -560,61 +551,159 @@ class FederationHandler(BaseHandler):
else:
raise
- @defer.inlineCallbacks
- def _process_received_pdu(self, origin, event, state, auth_chain):
- """ Called when we have a new pdu. We need to do auth checks and put it
- through the StateHandler.
+ async def _get_state_for_room(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ include_event_in_state: bool = False,
+ ) -> Tuple[List[EventBase], List[EventBase]]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+ include_event_in_state: if true, the event itself will be included in the
+ returned state event list.
+
+ Returns:
+ A list of events in the state, possibly including the event itself, and
+ a list of events in the auth chain for the given event.
"""
- room_id = event.room_id
- event_id = event.event_id
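+        # ask the remote server for the ids of the state and auth events at
+        # the given event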
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
- logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+ desired_events = set(state_event_ids + auth_event_ids)
+
+ if include_event_in_state:
+ desired_events.add(event_id)
+
+ event_map = await self._get_events_from_store_or_dest(
+ destination, room_id, desired_events
+ )
+
+ failed_to_fetch = desired_events - event_map.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state/auth events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
- event_ids = set()
- if state:
- event_ids |= {e.event_id for e in state}
- if auth_chain:
- event_ids |= {e.event_id for e in auth_chain}
+ remote_state = [
+ event_map[e_id] for e_id in state_event_ids if e_id in event_map
+ ]
- seen_ids = yield self.store.have_seen_events(event_ids)
+ if include_event_in_state:
+ remote_event = event_map.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
- if state and auth_chain is not None:
- # If we have any state or auth_chain given to us by the replication
- # layer, then we should handle them (if we haven't before.)
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
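+        # sort the auth chain by depth to give an approximately topological
+        # ordering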
+ auth_chain.sort(key=lambda e: e.depth)
- event_infos = []
+ return remote_state, auth_chain
- for e in itertools.chain(auth_chain, state):
- if e.event_id in seen_ids:
- continue
- e.internal_metadata.outlier = True
- auth_ids = e.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- event_infos.append({"event": e, "auth_events": auth})
- seen_ids.add(e.event_id)
+ async def _get_events_from_store_or_dest(
+ self, destination: str, room_id: str, event_ids: Iterable[str]
+ ) -> Dict[str, EventBase]:
+ """Fetch events from a remote destination, checking if we already have them.
- logger.info(
- "[%s %s] persisting newly-received auth/state events %s",
+ Persists any events we don't already have as outliers.
+
+        If we fail to fetch any of the events, a warning will be logged, and the
+        event will be omitted from the result. Likewise, any events which turn out
+        not to be in the given room are omitted.
+
+ Returns:
+ map from event_id to event
+ """
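+        # start with whatever we already have in the local database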
+ fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
+
+ missing_events = set(event_ids) - fetched_events.keys()
+
+ if missing_events:
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+ room_id,
+ )
+
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ (await self.store.get_events(missing_events, allow_rejected=True))
+ )
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = list(
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ )
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned auth/state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
room_id,
- event_id,
- [e["event"].event_id for e in event_infos],
)
- yield self._handle_new_events(origin, event_infos)
+
+ del fetched_events[bad_event_id]
+
+ return fetched_events
+
+ async def _process_received_pdu(
+ self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+ ):
+ """ Called when we have a new pdu. We need to do auth checks and put it
+ through the StateHandler.
+
+ Args:
+ origin: server sending the event
+
+ event: event to be persisted
+
+ state: Normally None, but if we are handling a gap in the graph
+ (ie, we are missing one or more prev_events), the resolved state at the
+ event
+ """
+ room_id = event.room_id
+ event_id = event.event_id
+
+ logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
try:
- context = yield self._handle_new_event(origin, event, state=state)
+ context = await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if not room:
try:
- yield self.store.store_room(
+ await self.store.store_room(
room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
@@ -627,11 +716,11 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = await context.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
- prev_state = yield self.store.get_event(
+ prev_state = await self.store.get_event(
prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
@@ -639,11 +728,10 @@ class FederationHandler(BaseHandler):
if newly_joined:
user = UserID.from_string(event.state_key)
- yield self.user_joined_room(user, room_id)
+ await self.user_joined_room(user, room_id)
@log_function
- @defer.inlineCallbacks
- def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -660,9 +748,7 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- room_version = yield self.store.get_room_version(room_id)
-
- events = yield self.federation_client.backfill(
+ events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -677,7 +763,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
- seen_events = yield self.store.have_events_in_timeline(
+ seen_events = await self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
@@ -690,6 +776,9 @@ class FederationHandler(BaseHandler):
event_ids = set(e.event_id for e in events)
+ # build a list of events whose prev_events weren't in the batch.
+ # (XXX: this will include events whose prev_events we already have; that doesn't
+ # sound right?)
edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
@@ -700,8 +789,11 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self.federation_client.get_state_for_room(
- destination=dest, room_id=room_id, event_id=e_id
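+            # (no need for include_event_in_state here: the extremity event
+            # itself is persisted, as a non-outlier, as part of the chunk below)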
+ state, auth = await self._get_state_for_room(
+ destination=dest,
+ room_id=room_id,
+ event_id=e_id,
+ include_event_in_state=False,
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
@@ -718,95 +810,11 @@ class FederationHandler(BaseHandler):
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
- missing_auth = required_auth - set(auth_events)
- failed_to_fetch = set()
-
- # Try and fetch any missing auth events from both DB and remote servers.
- # We repeatedly do this until we stop finding new auth events.
- while missing_auth - failed_to_fetch:
- logger.info("Missing auth for backfill: %r", missing_auth)
- ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
- auth_events.update(ret_events)
-
- required_auth.update(
- a_id for event in ret_events.values() for a_id in event.auth_event_ids()
- )
- missing_auth = required_auth - set(auth_events)
-
- if missing_auth - failed_to_fetch:
- logger.info(
- "Fetching missing auth for backfill: %r",
- missing_auth - failed_to_fetch,
- )
-
- results = yield make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.federation_client.get_pdu,
- [dest],
- event_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
- )
- for event_id in missing_auth - failed_to_fetch
- ],
- consumeErrors=True,
- )
- ).addErrback(unwrapFirstError)
- auth_events.update({a.event_id: a for a in results if a})
- required_auth.update(
- a_id
- for event in results
- if event
- for a_id in event.auth_event_ids()
- )
- missing_auth = required_auth - set(auth_events)
-
- failed_to_fetch = missing_auth - set(auth_events)
-
- seen_events = yield self.store.have_seen_events(
- set(auth_events.keys()) | set(state_events.keys())
- )
- # We now have a chunk of events plus associated state and auth chain to
- # persist. We do the persistence in two steps:
- # 1. Auth events and state get persisted as outliers, plus the
- # backward extremities get persisted (as non-outliers).
- # 2. The rest of the events in the chunk get persisted one by one, as
- # each one depends on the previous event for its state.
- #
- # The important thing is that events in the chunk get persisted as
- # non-outliers, including when those events are also in the state or
- # auth chain. Caution must therefore be taken to ensure that they are
- # not accidentally marked as outliers.
-
- # Step 1a: persist auth events that *don't* appear in the chunk
ev_infos = []
- for a in auth_events.values():
- # We only want to persist auth events as outliers that we haven't
- # seen and aren't about to persist as part of the backfilled chunk.
- if a.event_id in seen_events or a.event_id in event_map:
- continue
- a.internal_metadata.outlier = True
- ev_infos.append(
- {
- "event": a,
- "auth_events": {
- (
- auth_events[a_id].type,
- auth_events[a_id].state_key,
- ): auth_events[a_id]
- for a_id in a.auth_event_ids()
- if a_id in auth_events
- },
- }
- )
-
- # Step 1b: persist the events in the chunk we fetched state for (i.e.
- # the backwards extremities) as non-outliers.
+ # Step 1: persist the events in the chunk we fetched state for (i.e.
+ # the backwards extremities), with custom auth events and state
for e_id in events_to_state:
# For paranoia we ensure that these events are marked as
# non-outliers
@@ -814,10 +822,10 @@ class FederationHandler(BaseHandler):
assert not ev.internal_metadata.is_outlier()
ev_infos.append(
- {
- "event": ev,
- "state": events_to_state[e_id],
- "auth_events": {
+ _NewEventInfo(
+ event=ev,
+ state=events_to_state[e_id],
+ auth_events={
(
auth_events[a_id].type,
auth_events[a_id].state_key,
@@ -825,10 +833,10 @@ class FederationHandler(BaseHandler):
for a_id in ev.auth_event_ids()
if a_id in auth_events
},
- }
+ )
)
- yield self._handle_new_events(dest, ev_infos, backfilled=True)
+ await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -844,16 +852,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(dest, event, backfilled=True)
+ await self._handle_new_event(dest, event, backfilled=True)
return events
- @defer.inlineCallbacks
- def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+ extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
            logger.debug("Not backfilling as no extremities found.")
@@ -885,15 +892,17 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(list(extremities))
+ forward_events = await self.store.get_successor_events(list(extremities))
- extremities_events = yield self.store.get_events(
- forward_events, check_redacted=False, get_prev_content=False
+ extremities_events = await self.store.get_events(
+ forward_events,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
- filtered_extremities = yield filter_events_for_server(
+ filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
list(extremities_events.values()),
@@ -923,7 +932,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = yield self.state_handler.get_current_state(room_id)
+ curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
"""Get joined domains from state
@@ -962,12 +971,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- @defer.inlineCallbacks
- def try_backfill(domains):
+ async def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- yield self.backfill(
+ await self.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -998,7 +1006,7 @@ class FederationHandler(BaseHandler):
return False
- success = yield try_backfill(likely_domains)
+ success = await try_backfill(likely_domains)
if success:
return True
@@ -1012,7 +1020,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
- states = yield make_deferred_yieldable(
+ states = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
@@ -1022,7 +1030,7 @@ class FederationHandler(BaseHandler):
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False,
)
@@ -1038,7 +1046,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill(
+ success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
@@ -1048,6 +1056,56 @@ class FederationHandler(BaseHandler):
return False
+ async def _get_events_and_persist(
+ self, destination: str, room_id: str, events: Iterable[str]
+ ):
+ """Fetch the given events from a server, and persist them as outliers.
+
+        Logs a warning if we can't find a requested event.
+ """
+
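+        # we need the room version in order to parse events we fetch over
+        # federation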
+ room_version = await self.store.get_room_version(room_id)
+
+ event_infos = []
+
+ async def get_event(event_id: str):
+ with nested_logging_context(event_id):
+ try:
+ event = await self.federation_client.get_pdu(
+ [destination], event_id, room_version, outlier=True,
+ )
+ if event is None:
+ logger.warning(
+ "Server %s didn't return event %s", destination, event_id,
+ )
+ return
+
+ # recursively fetch the auth events for this event
+ auth_events = await self._get_events_from_store_or_dest(
+ destination, room_id, event.auth_event_ids()
+ )
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = auth_events.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+
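+                # these events are fetched as outliers, so they have no state;
+                # hence state=None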
+ event_infos.append(_NewEventInfo(event, None, auth))
+
+ except Exception as e:
+ logger.warning(
+ "Error fetching missing state/auth event %s: %s %s",
+ event_id,
+ type(e),
+ e,
+ )
+
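+        # fetch the events (and, recursively, their auth events) from the
+        # remote server, at most five at a time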
+ await concurrently_execute(get_event, events, 5)
+
+ await self._handle_new_events(
+ destination, event_infos,
+ )
+
def _sanity_check_event(self, ev):
"""
Do some early sanity checks of a received event
@@ -1187,7 +1245,7 @@ class FederationHandler(BaseHandler):
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = yield self.store.get_room_predecessor(room_id)
- if not predecessor:
+ if not predecessor or not isinstance(predecessor.get("room_id"), str):
return
old_room_id = predecessor["room_id"]
logger.debug(
@@ -1215,8 +1273,7 @@ class FederationHandler(BaseHandler):
return True
- @defer.inlineCallbacks
- def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@@ -1232,7 +1289,7 @@ class FederationHandler(BaseHandler):
p.room_id,
)
with nested_logging_context(p.event_id):
- yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1362,7 +1419,7 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield self.user_joined_room(user, event.room_id)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
auth_chain = yield self.store.get_auth_chain(state_ids)
@@ -1428,9 +1485,9 @@ class FederationHandler(BaseHandler):
return event
@defer.inlineCallbacks
- def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
+ def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts, room_id, user_id, "leave"
+ target_hosts, room_id, user_id, "leave", content=content
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -1710,7 +1767,12 @@ class FederationHandler(BaseHandler):
return context
@defer.inlineCallbacks
- def _handle_new_events(self, origin, event_infos, backfilled=False):
+ def _handle_new_events(
+ self,
+ origin: str,
+ event_infos: Iterable[_NewEventInfo],
+ backfilled: bool = False,
+ ):
"""Creates the appropriate contexts and persists events. The events
should not depend on one another, e.g. this should be used to persist
a bunch of outliers, but not a chunk of individual events that depend
@@ -1720,14 +1782,14 @@ class FederationHandler(BaseHandler):
"""
@defer.inlineCallbacks
- def prep(ev_info):
- event = ev_info["event"]
+ def prep(ev_info: _NewEventInfo):
+ event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = yield self._prep_event(
origin,
event,
- state=ev_info.get("state"),
- auth_events=ev_info.get("auth_events"),
+ state=ev_info.state,
+ auth_events=ev_info.auth_events,
backfilled=backfilled,
)
return res
@@ -1741,7 +1803,7 @@ class FederationHandler(BaseHandler):
yield self.persist_events_and_notify(
[
- (ev_info["event"], context)
+ (ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
],
backfilled=backfilled,
@@ -1843,7 +1905,14 @@ class FederationHandler(BaseHandler):
yield self.persist_events_and_notify([(event, new_event_context)])
@defer.inlineCallbacks
- def _prep_event(self, origin, event, state, auth_events, backfilled):
+ def _prep_event(
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ auth_events: Optional[StateMap[EventBase]],
+ backfilled: bool,
+ ):
"""
Args:
@@ -1851,7 +1920,7 @@ class FederationHandler(BaseHandler):
event:
state:
auth_events:
- backfilled (bool)
+ backfilled:
Returns:
Deferred, which resolves to synapse.events.snapshot.EventContext
@@ -1859,7 +1928,7 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event, old_state=state)
if not auth_events:
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
auth_events_ids = yield self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
@@ -1887,15 +1956,16 @@ class FederationHandler(BaseHandler):
return context
@defer.inlineCallbacks
- def _check_for_soft_fail(self, event, state, backfilled):
+ def _check_for_soft_fail(
+ self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
+ ):
"""Checks if we should soft fail the event, if so marks the event as
such.
Args:
- event (FrozenEvent)
- state (dict|None): The state at the event if we don't have all the
- event's prev events
- backfilled (bool): Whether the event is from backfill
+ event
+ state: The state at the event if we don't have all the event's prev events
+ backfilled: Whether the event is from backfill
Returns:
Deferred
@@ -2040,8 +2110,10 @@ class FederationHandler(BaseHandler):
auth_events (dict[(str, str)->synapse.events.EventBase]):
Map from (event_type, state_key) to event
- What we expect the event's auth_events to be, based on the event's
- position in the dag. I think? maybe??
+                Normally, these are our calculated auth_events based on the state
+                of the room at the event's position in the DAG, though occasionally
+                (eg if the event is an outlier) they may be the auth events claimed
+                by the remote server.
Also NB that this function adds entries to it.
Returns:
@@ -2091,35 +2163,35 @@ class FederationHandler(BaseHandler):
origin (str):
event (synapse.events.EventBase):
context (synapse.events.snapshot.EventContext):
+
auth_events (dict[(str, str)->synapse.events.EventBase]):
+ Map from (event_type, state_key) to event
+
+                Normally, these are our calculated auth_events based on the state
+                of the room at the event's position in the DAG, though occasionally
+                (eg if the event is an outlier) they may be the auth events claimed
+                by the remote server.
+
+ Also NB that this function adds entries to it.
Returns:
defer.Deferred[EventContext]: updated context
"""
event_auth_events = set(event.auth_event_ids())
- if event.is_state():
- event_key = (event.type, event.state_key)
- else:
- event_key = None
-
- # if the event's auth_events refers to events which are not in our
- # calculated auth_events, we need to fetch those events from somewhere.
- #
- # we start by fetching them from the store, and then try calling /event_auth/.
+ # missing_auth is the set of the event's auth_events which we don't yet have
+ # in auth_events.
missing_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
+ # if we have missing events, we need to fetch those events from somewhere.
+ #
+ # we start by checking if they are in the store, and then try calling /event_auth/.
if missing_auth:
- # TODO: can we use store.have_seen_events here instead?
- have_events = yield self.store.get_seen_events_with_rejections(missing_auth)
- logger.debug("Got events %s from store", have_events)
- missing_auth.difference_update(have_events.keys())
- else:
- have_events = {}
-
- have_events.update({e.event_id: "" for e in auth_events.values()})
+ have_events = yield self.store.have_seen_events(missing_auth)
+ logger.debug("Events %s are in the store", have_events)
+ missing_auth.difference_update(have_events)
if missing_auth:
# If we don't have all the auth events, we need to get them.
@@ -2165,19 +2237,18 @@ class FederationHandler(BaseHandler):
except AuthError:
pass
- have_events = yield self.store.get_seen_events_with_rejections(
- event.auth_event_ids()
- )
except Exception:
- # FIXME:
logger.exception("Failed to get auth chain")
if event.internal_metadata.is_outlier():
+ # XXX: given that, for an outlier, we'll be working with the
+ # event's *claimed* auth events rather than those we calculated:
+ # (a) is there any point in this test, since different_auth below will
+ # obviously be empty
+ # (b) alternatively, why don't we do it earlier?
logger.info("Skipping auth_event fetch for outlier")
return context
- # FIXME: Assumes we have and stored all the state for all the
- # prev_events
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
@@ -2191,53 +2262,58 @@ class FederationHandler(BaseHandler):
different_auth,
)
- room_version = yield self.store.get_room_version(event.room_id)
+ # XXX: currently this checks for redactions but I'm not convinced that is
+ # necessary?
+ different_events = yield self.store.get_events_as_list(different_auth)
- different_events = yield make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.store.get_event, d, allow_none=True, allow_rejected=False
- )
- for d in different_auth
- if d in have_events and not have_events[d]
- ],
- consumeErrors=True,
- )
- ).addErrback(unwrapFirstError)
+ for d in different_events:
+ if d.room_id != event.room_id:
+ logger.warning(
+ "Event %s refers to auth_event %s which is in a different room",
+ event.event_id,
+ d.event_id,
+ )
- if different_events:
- local_view = dict(auth_events)
- remote_view = dict(auth_events)
- remote_view.update(
- {(d.type, d.state_key): d for d in different_events if d}
- )
+ # don't attempt to resolve the claimed auth events against our own
+ # in this case: just use our own auth events.
+ #
+ # XXX: should we reject the event in this case? It feels like we should,
+ # but then shouldn't we also do so if we've failed to fetch any of the
+ # auth events?
+ return context
- new_state = yield self.state_handler.resolve_events(
- room_version,
- [list(local_view.values()), list(remote_view.values())],
- event,
- )
+ # now we state-resolve between our own idea of the auth events, and the remote's
+ # idea of them.
- logger.info(
- "After state res: updating auth_events with new state %s",
- {
- (d.type, d.state_key): d.event_id
- for d in new_state.values()
- if auth_events.get((d.type, d.state_key)) != d
- },
- )
+ local_state = auth_events.values()
+ remote_auth_events = dict(auth_events)
+ remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
+ remote_state = remote_auth_events.values()
- auth_events.update(new_state)
+ room_version = yield self.store.get_room_version(event.room_id)
+ new_state = yield self.state_handler.resolve_events(
+ room_version, (local_state, remote_state), event
+ )
- context = yield self._update_context_for_auth_events(
- event, context, auth_events, event_key
- )
+ logger.info(
+ "After state res: updating auth_events with new state %s",
+ {
+ (d.type, d.state_key): d.event_id
+ for d in new_state.values()
+ if auth_events.get((d.type, d.state_key)) != d
+ },
+ )
+
+ auth_events.update(new_state)
+
+ context = yield self._update_context_for_auth_events(
+ event, context, auth_events
+ )
return context
@defer.inlineCallbacks
- def _update_context_for_auth_events(self, event, context, auth_events, event_key):
+ def _update_context_for_auth_events(self, event, context, auth_events):
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
@@ -2246,24 +2322,27 @@ class FederationHandler(BaseHandler):
context (synapse.events.snapshot.EventContext): initial event context
- auth_events (dict[(str, str)->str]): Events to update in the event
+ auth_events (dict[(str, str)->EventBase]): Events to update in the event
context.
- event_key ((str, str)): (type, state_key) for the current event.
- this will not be included in the current_state in the context.
-
Returns:
Deferred[EventContext]: new event context
"""
+ # exclude the state key of the new event from the current_state in the context.
+ if event.is_state():
+ event_key = (event.type, event.state_key)
+ else:
+ event_key = None
state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
}
- current_state_ids = yield context.get_current_state_ids(self.store)
+
+ current_state_ids = yield context.get_current_state_ids()
current_state_ids = dict(current_state_ids)
current_state_ids.update(state_updates)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = dict(prev_state_ids)
prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
@@ -2459,7 +2538,7 @@ class FederationHandler(BaseHandler):
room_version, event_dict, event, context
)
- EventValidator().validate_new(event)
+ EventValidator().validate_new(event, self.config)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
@@ -2547,7 +2626,7 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
@@ -2574,7 +2653,7 @@ class FederationHandler(BaseHandler):
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
- EventValidator().validate_new(event)
+ EventValidator().validate_new(event, self.config)
return (event, context)
@defer.inlineCallbacks
@@ -2595,7 +2674,7 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
@@ -2708,6 +2787,11 @@ class FederationHandler(BaseHandler):
event_and_contexts, backfilled=backfilled
)
+ if self._ephemeral_messages_enabled:
+ for (event, context) in event_and_contexts:
+ # If there's an expiry timestamp on the event, schedule its expiry.
+ self._message_handler.maybe_schedule_expiry(event)
+
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
yield self._notify_persisted_event(event, max_stream_id)
@@ -2764,7 +2848,7 @@ class FederationHandler(BaseHandler):
room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
- return user_joined_room(self.distributor, user, room_id)
+ return defer.succeed(user_joined_room(self.distributor, user, room_id))
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
|