diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0dcf9abf8..3992b4791b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
import itertools
import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
@@ -63,9 +63,10 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
from synapse.util import batch_iter, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
@@ -164,8 +165,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- @defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -175,17 +175,15 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
-
- Returns (Deferred): completes with None
"""
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+ logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.store.get_event(
+ existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@@ -229,7 +227,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -245,12 +243,12 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -270,7 +268,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
@@ -278,13 +276,19 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
)
- yield self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ )
# Update the set of things we've seen after trying to
# fetch the missing stuff
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -292,14 +296,6 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
- elif missing_prevs:
- logger.info(
- "[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id,
- event_id,
- len(missing_prevs),
- shortstr(missing_prevs),
- )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -350,7 +346,7 @@ class FederationHandler(BaseHandler):
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@@ -370,7 +366,7 @@ class FederationHandler(BaseHandler):
p,
)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
@@ -379,11 +375,11 @@ class FederationHandler(BaseHandler):
(
remote_state,
got_auth_chain,
- ) = yield self._get_state_for_room(origin, room_id, p)
+ ) = await self._get_state_for_room(origin, room_id, p)
# we want the state *after* p; _get_state_for_room returns the
# state *before* p.
- remote_event = yield self.federation_client.get_pdu(
+ remote_event = await self.federation_client.get_pdu(
[origin], p, room_version, outlier=True
)
@@ -408,7 +404,8 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- state_map = yield resolve_events_with_store(
+ state_map = await resolve_events_with_store(
+ room_id,
room_version,
state_maps,
event_map,
@@ -420,7 +417,7 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
- evs = yield self.store.get_events(
+ evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
check_redacted=False,
@@ -444,12 +441,11 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
- yield self._process_received_pdu(
+ await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
- @defer.inlineCallbacks
- def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@@ -461,12 +457,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
- latest = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -530,7 +526,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timout to 60s and see what happens.
try:
- missing_events = yield self.federation_client.get_missing_events(
+ missing_events = await self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@@ -569,7 +565,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+ await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
@@ -676,47 +672,21 @@ class FederationHandler(BaseHandler):
def _process_received_pdu(self, origin, event, state, auth_chain):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
+
+ Args:
+ origin: server sending the event
+
+ event: event to be persisted
+
+ state: Normally None, but if we are handling a gap in the graph
+ (ie, we are missing one or more prev_events), the resolved state at the
+ event
"""
room_id = event.room_id
event_id = event.event_id
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
- event_ids = set()
- if state:
- event_ids |= {e.event_id for e in state}
- if auth_chain:
- event_ids |= {e.event_id for e in auth_chain}
-
- seen_ids = yield self.store.have_seen_events(event_ids)
-
- if state and auth_chain is not None:
- # If we have any state or auth_chain given to us by the replication
- # layer, then we should handle them (if we haven't before.)
-
- event_infos = []
-
- for e in itertools.chain(auth_chain, state):
- if e.event_id in seen_ids:
- continue
- e.internal_metadata.outlier = True
- auth_ids = e.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- event_infos.append(_NewEventInfo(event=e, auth_events=auth))
- seen_ids.add(e.event_id)
-
- logger.info(
- "[%s %s] persisting newly-received auth/state events %s",
- room_id,
- event_id,
- [e.event.event_id for e in event_infos],
- )
- yield self._handle_new_events(origin, event_infos)
-
try:
context = yield self._handle_new_event(origin, event, state=state)
except AuthError as e:
@@ -754,8 +724,7 @@ class FederationHandler(BaseHandler):
yield self.user_joined_room(user, room_id)
@log_function
- @defer.inlineCallbacks
- def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -772,9 +741,7 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- room_version = yield self.store.get_room_version(room_id)
-
- events = yield self.federation_client.backfill(
+ events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -789,7 +756,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
- seen_events = yield self.store.have_events_in_timeline(
+ seen_events = await self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
@@ -802,6 +769,9 @@ class FederationHandler(BaseHandler):
event_ids = set(e.event_id for e in events)
+ # build a list of events whose prev_events weren't in the batch.
+ # (XXX: this will include events whose prev_events we already have; that doesn't
+ # sound right?)
edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
@@ -812,7 +782,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self._get_state_for_room(
+ state, auth = await self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
@@ -830,95 +800,11 @@ class FederationHandler(BaseHandler):
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
- missing_auth = required_auth - set(auth_events)
- failed_to_fetch = set()
-
- # Try and fetch any missing auth events from both DB and remote servers.
- # We repeatedly do this until we stop finding new auth events.
- while missing_auth - failed_to_fetch:
- logger.info("Missing auth for backfill: %r", missing_auth)
- ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
- auth_events.update(ret_events)
- required_auth.update(
- a_id for event in ret_events.values() for a_id in event.auth_event_ids()
- )
- missing_auth = required_auth - set(auth_events)
-
- if missing_auth - failed_to_fetch:
- logger.info(
- "Fetching missing auth for backfill: %r",
- missing_auth - failed_to_fetch,
- )
-
- results = yield make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.federation_client.get_pdu,
- [dest],
- event_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
- )
- for event_id in missing_auth - failed_to_fetch
- ],
- consumeErrors=True,
- )
- ).addErrback(unwrapFirstError)
- auth_events.update({a.event_id: a for a in results if a})
- required_auth.update(
- a_id
- for event in results
- if event
- for a_id in event.auth_event_ids()
- )
- missing_auth = required_auth - set(auth_events)
-
- failed_to_fetch = missing_auth - set(auth_events)
-
- seen_events = yield self.store.have_seen_events(
- set(auth_events.keys()) | set(state_events.keys())
- )
-
- # We now have a chunk of events plus associated state and auth chain to
- # persist. We do the persistence in two steps:
- # 1. Auth events and state get persisted as outliers, plus the
- # backward extremities get persisted (as non-outliers).
- # 2. The rest of the events in the chunk get persisted one by one, as
- # each one depends on the previous event for its state.
- #
- # The important thing is that events in the chunk get persisted as
- # non-outliers, including when those events are also in the state or
- # auth chain. Caution must therefore be taken to ensure that they are
- # not accidentally marked as outliers.
-
- # Step 1a: persist auth events that *don't* appear in the chunk
ev_infos = []
- for a in auth_events.values():
- # We only want to persist auth events as outliers that we haven't
- # seen and aren't about to persist as part of the backfilled chunk.
- if a.event_id in seen_events or a.event_id in event_map:
- continue
- a.internal_metadata.outlier = True
- ev_infos.append(
- _NewEventInfo(
- event=a,
- auth_events={
- (
- auth_events[a_id].type,
- auth_events[a_id].state_key,
- ): auth_events[a_id]
- for a_id in a.auth_event_ids()
- if a_id in auth_events
- },
- )
- )
-
- # Step 1b: persist the events in the chunk we fetched state for (i.e.
- # the backwards extremities) as non-outliers.
+ # Step 1: persist the events in the chunk we fetched state for (i.e.
+ # the backwards extremities), with custom auth events and state
for e_id in events_to_state:
# For paranoia we ensure that these events are marked as
# non-outliers
@@ -940,7 +826,7 @@ class FederationHandler(BaseHandler):
)
)
- yield self._handle_new_events(dest, ev_infos, backfilled=True)
+ await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -956,16 +842,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(dest, event, backfilled=True)
+ await self._handle_new_event(dest, event, backfilled=True)
return events
- @defer.inlineCallbacks
- def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+ extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -997,15 +882,17 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(list(extremities))
+ forward_events = await self.store.get_successor_events(list(extremities))
- extremities_events = yield self.store.get_events(
- forward_events, check_redacted=False, get_prev_content=False
+ extremities_events = await self.store.get_events(
+ forward_events,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
- filtered_extremities = yield filter_events_for_server(
+ filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
list(extremities_events.values()),
@@ -1035,7 +922,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = yield self.state_handler.get_current_state(room_id)
+ curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
"""Get joined domains from state
@@ -1074,12 +961,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- @defer.inlineCallbacks
- def try_backfill(domains):
+ async def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- yield self.backfill(
+ await self.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -1110,7 +996,7 @@ class FederationHandler(BaseHandler):
return False
- success = yield try_backfill(likely_domains)
+ success = await try_backfill(likely_domains)
if success:
return True
@@ -1124,7 +1010,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
- states = yield make_deferred_yieldable(
+ states = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
@@ -1134,7 +1020,7 @@ class FederationHandler(BaseHandler):
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False,
)
@@ -1150,7 +1036,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill(
+ success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
@@ -1160,6 +1046,56 @@ class FederationHandler(BaseHandler):
return False
+ async def _get_events_and_persist(
+ self, destination: str, room_id: str, events: Iterable[str]
+ ):
+ """Fetch the given events from a server, and persist them as outliers.
+
+ Logs a warning if we can't find the given event.
+ """
+
+ room_version = await self.store.get_room_version(room_id)
+
+ event_infos = []
+
+ async def get_event(event_id: str):
+ with nested_logging_context(event_id):
+ try:
+ event = await self.federation_client.get_pdu(
+ [destination], event_id, room_version, outlier=True,
+ )
+ if event is None:
+ logger.warning(
+ "Server %s didn't return event %s", destination, event_id,
+ )
+ return
+
+ # recursively fetch the auth events for this event
+ auth_events = await self._get_events_from_store_or_dest(
+ destination, room_id, event.auth_event_ids()
+ )
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = auth_events.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+
+ event_infos.append(_NewEventInfo(event, None, auth))
+
+ except Exception as e:
+ logger.warning(
+ "Error fetching missing state/auth event %s: %s %s",
+ event_id,
+ type(e),
+ e,
+ )
+
+ await concurrently_execute(get_event, events, 5)
+
+ await self._handle_new_events(
+ destination, event_infos,
+ )
+
def _sanity_check_event(self, ev):
"""
Do some early sanity checks of a received event
@@ -1299,7 +1235,7 @@ class FederationHandler(BaseHandler):
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = yield self.store.get_room_predecessor(room_id)
- if not predecessor:
+ if not predecessor or not isinstance(predecessor.get("room_id"), str):
return
old_room_id = predecessor["room_id"]
logger.debug(
@@ -1327,8 +1263,7 @@ class FederationHandler(BaseHandler):
return True
- @defer.inlineCallbacks
- def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@@ -1344,7 +1279,7 @@ class FederationHandler(BaseHandler):
p.room_id,
)
with nested_logging_context(p.event_id):
- yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1501,8 +1436,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = yield self.store.is_room_published(event.room_id)
+
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1542,7 +1484,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts, room_id, user_id, "leave", content=content,
+ target_hosts, room_id, user_id, "leave", content=content
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -2903,7 +2845,7 @@ class FederationHandler(BaseHandler):
room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
- return user_joined_room(self.distributor, user, room_id)
+ return defer.succeed(user_joined_room(self.distributor, user, room_id))
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
|