diff --git a/changelog.d/10692.misc b/changelog.d/10692.misc
new file mode 100644
index 0000000000..a1b0def76b
--- /dev/null
+++ b/changelog.d/10692.misc
@@ -0,0 +1 @@
+Split the event-processing methods in `FederationHandler` into a separate `FederationEventHandler`.
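Note: the new `FederationEventHandler` class itself is added in a separate file which this excerpt does not include. As orientation, here is a minimal sketch of the shape the moved code takes; everything in it other than the `get_federation_event_handler` accessor used in the hunks below is an assumption:

    class FederationEventHandler:
        """Handles events newly received over federation, deciding whether to
        persist them and notifying clients where appropriate."""

        def __init__(self, hs):
            self.store = hs.get_datastore()
            # Queue of (pdu, origin) tuples per room, used while a join to that
            # room is in progress. This state moves here from FederationHandler,
            # which is why the hunks below reach it via
            # self._federation_event_handler.room_queues.
            self.room_queues = {}

        async def on_receive_pdu(self, origin, pdu):
            ...  # body moved, essentially unchanged, from FederationHandler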
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e1b58d40c5..214ee948fa 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -110,6 +110,7 @@ class FederationServer(FederationBase):
super().__init__(hs)
self.handler = hs.get_federation_handler()
+ self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
@@ -787,7 +788,9 @@ class FederationServer(FederationBase):
event = await self._check_sigs_and_hash(room_version, event)
- return await self.handler.on_send_membership_event(origin, event)
+ return await self._federation_event_handler.on_send_membership_event(
+ origin, event
+ )
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
@@ -1005,7 +1008,7 @@ class FederationServer(FederationBase):
async with lock:
logger.info("handling received PDU: %s", event)
try:
- await self.handler.on_receive_pdu(origin, event)
+ await self._federation_event_handler.on_receive_pdu(origin, event)
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
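For the delegation above to work, `HomeServer` also needs a `get_federation_event_handler` accessor. That change is not part of this excerpt; Synapse exposes handler singletons via `@cache_in_self` getters in `synapse/server.py`, so it presumably looks like this sketch:

    @cache_in_self
    def get_federation_event_handler(self) -> "FederationEventHandler":
        return FederationEventHandler(self)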
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6fa2fc8f52..daf1d3bfb3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,23 +17,9 @@
import itertools
import logging
-from collections.abc import Container
from http import HTTPStatus
-from typing import (
- TYPE_CHECKING,
- Collection,
- Dict,
- Iterable,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Union,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
-import attr
-from prometheus_client import Counter
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -41,19 +27,12 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse import event_auth
-from synapse.api.constants import (
- EventContentFields,
- EventTypes,
- Membership,
- RejectedReason,
- RoomEncryptionAlgorithms,
-)
+from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.api.errors import (
AuthError,
CodeMessageException,
Codes,
FederationDeniedError,
- FederationError,
HttpResponseException,
NotFoundError,
RequestSendFailed,
@@ -61,7 +40,6 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
-from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
@@ -75,28 +53,14 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.logging.utils import log_function
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
- ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
-from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import (
- JsonDict,
- MutableStateMap,
- PersistedEventPosition,
- RoomStreamToken,
- StateMap,
- UserID,
- get_domain_from_id,
-)
-from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.iterutils import batch_iter
+from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
-from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
if TYPE_CHECKING:
@@ -104,45 +68,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-soft_failed_event_counter = Counter(
- "synapse_federation_soft_failed_events_total",
- "Events received over federation that we marked as soft_failed",
-)
-
-
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _NewEventInfo:
- """Holds information about a received event, ready for passing to _auth_and_persist_events
-
- Attributes:
- event: the received event
-
- claimed_auth_event_map: a map of (type, state_key) => event for the event's
- claimed auth_events.
-
- This can include events which have not yet been persisted, in the case that
- we are backfilling a batch of events.
-
- Note: May be incomplete if we were unable to find all of the claimed auth
- events. Also, treat the contents with caution: the events might also have
- been rejected, might not yet have been authorized themselves, or they might
- be in the wrong room.
-
- """
-
- event: EventBase
- claimed_auth_event_map: StateMap[EventBase]
-
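`_NewEventInfo` moves to the new module along with its consumers. For readers unfamiliar with `StateMap`, the `claimed_auth_event_map` is keyed by (event type, state_key) pairs; an illustrative value, with hypothetical event objects:

    claimed_auth_event_map = {
        ("m.room.create", ""): create_event,
        ("m.room.power_levels", ""): power_levels_event,
        ("m.room.member", "@alice:example.org"): alice_join_event,
    }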
class FederationHandler(BaseHandler):
- """Handles events that originated from federation.
- Responsible for:
- a) handling received Pdus before handing them on as Events to the rest
- of the homeserver (including auth and state conflict resolutions)
- b) converting events that were produced by local clients that may need
- to be sent to remote homeservers.
- c) doing the necessary dances to invite remote users and join remote
- rooms.
+ """Handles general incoming federation requests
+
+ Incoming events are *not* handled here, for which see FederationEventHandler.
"""
def __init__(self, hs: "HomeServer"):
@@ -155,652 +85,35 @@ class FederationHandler(BaseHandler):
self.state_store = self.storage.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
- self._state_resolution_handler = hs.get_state_resolution_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
- self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._event_auth_handler = hs.get_event_auth_handler()
- self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_proxied_blacklisted_http_client()
- self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
+ self._federation_event_handler = hs.get_federation_event_handler()
- self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
)
if hs.config.worker_app:
- self._user_device_resync = (
- ReplicationUserDevicesResyncRestServlet.make_client(hs)
- )
self._maybe_store_room_on_outlier_membership = (
ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
)
else:
- self._device_list_updater = hs.get_device_handler().device_list_updater
self._maybe_store_room_on_outlier_membership = (
self.store.maybe_store_room_on_outlier_membership
)
- # When joining a room we need to queue any events for that room up.
- # For each room, a list of (pdu, origin) tuples.
- self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
- self._room_pdu_linearizer = Linearizer("fed_room_pdu")
-
self._room_backfill = Linearizer("room_backfill")
self.third_party_event_rules = hs.get_third_party_event_rules()
- self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
-
- async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
- """Process a PDU received via a federation /send/ transaction
-
- Args:
- origin: server which initiated the /send/ transaction. Will
- be used to fetch missing events or state.
- pdu: received PDU
- """
-
- room_id = pdu.room_id
- event_id = pdu.event_id
-
- # We reprocess pdus when we have seen them only as outliers
- existing = await self.store.get_event(
- event_id, allow_none=True, allow_rejected=True
- )
-
- # FIXME: Currently we fetch an event again when we already have it
- # if it has been marked as an outlier.
- if existing:
- if not existing.internal_metadata.is_outlier():
- logger.info(
- "Ignoring received event %s which we have already seen", event_id
- )
- return
- if pdu.internal_metadata.is_outlier():
- logger.info(
- "Ignoring received outlier %s which we already have as an outlier",
- event_id,
- )
- return
- logger.info("De-outliering event %s", event_id)
-
- # do some initial sanity-checking of the event. In particular, make
- # sure it doesn't have hundreds of prev_events or auth_events, which
- # could cause a huge state resolution or cascade of event fetches.
- try:
- self._sanity_check_event(pdu)
- except SynapseError as err:
- logger.warning("Received event failed sanity checks")
- raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
-
- # If we are currently in the process of joining this room, then we
- # queue up events for later processing.
- if room_id in self.room_queues:
- logger.info(
- "Queuing PDU from %s for now: join in progress",
- origin,
- )
- self.room_queues[room_id].append((pdu, origin))
- return
-
- # If we're not in the room just ditch the event entirely. This is
- # probably an old server that has come back and thinks we're still in
- # the room (or we've been rejoined to the room by a state reset).
- #
- # Note that if we were never in the room then we would have already
- # dropped the event, since we wouldn't know the room version.
- is_in_room = await self._event_auth_handler.check_host_in_room(
- room_id, self.server_name
- )
- if not is_in_room:
- logger.info(
- "Ignoring PDU from %s as we're not in the room",
- origin,
- )
- return None
-
- # Check that the event passes auth based on the state at the event. This is
- # done for events that are to be added to the timeline (non-outliers).
- #
- # Get missing pdus if necessary:
- # - Fetching any missing prev events to fill in gaps in the graph
- # - Fetching state if we have a hole in the graph
- if not pdu.internal_metadata.is_outlier():
- prevs = set(pdu.prev_event_ids())
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
-
- if missing_prevs:
- # We only backfill backwards to the min depth.
- min_depth = await self.get_min_depth_for_context(pdu.room_id)
- logger.debug("min_depth: %d", min_depth)
-
- if min_depth is not None and pdu.depth > min_depth:
- # If we're missing stuff, ensure we only fetch stuff one
- # at a time.
- logger.info(
- "Acquiring room lock to fetch %d missing prev_events: %s",
- len(missing_prevs),
- shortstr(missing_prevs),
- )
- with (await self._room_pdu_linearizer.queue(pdu.room_id)):
- logger.info(
- "Acquired room lock to fetch %d missing prev_events",
- len(missing_prevs),
- )
-
- try:
- await self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
- except Exception as e:
- raise Exception(
- "Error fetching missing prev_events for %s: %s"
- % (event_id, e)
- ) from e
-
- # Update the set of things we've seen after trying to
- # fetch the missing stuff
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
-
- if not missing_prevs:
- logger.info("Found all missing prev_events")
-
- if missing_prevs:
- # since this event was pushed to us, it is possible for it to
- # become the only forward-extremity in the room, and we would then
- # trust its state to be the state for the whole room. This is very
- # bad. Further, if the event was pushed to us, there is no excuse
- # for us not to have all the prev_events. (XXX: apart from
- # min_depth?)
- #
- # We therefore reject any such events.
- logger.warning(
- "Rejecting: failed to fetch %d prev events: %s",
- len(missing_prevs),
- shortstr(missing_prevs),
- )
- raise FederationError(
- "ERROR",
- 403,
- (
- "Your server isn't divulging details about prev_events "
- "referenced in this event."
- ),
- affected=pdu.event_id,
- )
-
- await self._process_received_pdu(origin, pdu, state=None)
-
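The gap detection in `on_receive_pdu` is plain set arithmetic over event IDs; a toy illustration with made-up IDs:

    prevs = {"$a", "$b", "$c"}    # pdu.prev_event_ids()
    seen = {"$a"}                 # the subset already in our timeline
    missing_prevs = prevs - seen  # {"$b", "$c"}: fetch them, or reject the PDU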
- async def _get_missing_events_for_pdu(
- self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
- ) -> None:
- """
- Args:
- origin: Origin of the pdu; the server we will query for the missing events
- pdu: received pdu
- prevs: List of event ids which we are missing
- min_depth: Minimum depth of events to return.
- """
-
- room_id = pdu.room_id
- event_id = pdu.event_id
-
- seen = await self.store.have_events_in_timeline(prevs)
-
- if not prevs - seen:
- return
-
- latest_list = await self.store.get_latest_event_ids_in_room(room_id)
-
- # We add the prev events that we have seen to the latest
- # list to ensure the remote server doesn't give them to us
- latest = set(latest_list)
- latest |= seen
-
- logger.info(
- "Requesting missing events between %s and %s",
- shortstr(latest),
- event_id,
- )
-
- # XXX: we set timeout to 10s to help workaround
- # https://github.com/matrix-org/synapse/issues/1733.
- # The reason is to avoid holding the linearizer lock
- # whilst processing inbound /send transactions, causing
- # FDs to stack up and block other inbound transactions
- # which empirically can currently take up to 30 minutes.
- #
- # N.B. this explicitly disables retry attempts.
- #
- # N.B. this also increases our chances of falling back to
- # fetching fresh state for the room if the missing event
- # can't be found, which slightly reduces our security.
- # it may also increase our DAG extremity count for the room,
- # causing additional state resolution? See #1760.
- # However, fetching state doesn't hold the linearizer lock
- # apparently.
- #
- # see https://github.com/matrix-org/synapse/pull/1744
- #
- # ----
- #
- # Update richvdh 2018/09/18: There are a number of problems with timing this
- # request out aggressively on the client side:
- #
- # - it plays badly with the server-side rate-limiter, which starts tarpitting you
- # if you send too many requests at once, so you end up with the server carefully
- # working through the backlog of your requests, which you have already timed
- # out.
- #
- # - for this request in particular, we now (as of
- # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the
- # server can't produce a plausible-looking set of prev_events - so we become
- # much more likely to reject the event.
- #
- # - contrary to what it says above, we do *not* fall back to fetching fresh state
- # for the room if get_missing_events times out. Rather, we give up processing
- # the PDU whose prevs we are missing, which then makes it much more likely that
- # we'll end up back here for the *next* PDU in the list, which exacerbates the
- # problem.
- #
- # - the aggressive 10s timeout was introduced to deal with incoming federation
- # requests taking 8 hours to process. It's not entirely clear why that was going
- # on; certainly there were other issues causing traffic storms which are now
- # resolved, and I think in any case we may be more sensible about our locking
- # now. We're *certainly* more sensible about our logging.
- #
- # All that said: Let's try increasing the timeout to 60s and see what happens.
-
- try:
- missing_events = await self.federation_client.get_missing_events(
- origin,
- room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- timeout=60000,
- )
- except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
- # We failed to get the missing events, but since we need to handle
- # the case of `get_missing_events` not returning the necessary
- # events anyway, it is safe to simply log the error and continue.
- logger.warning("Failed to get prev_events: %s", e)
- return
-
- logger.info("Got %d prev_events", len(missing_events))
- await self._process_pulled_events(origin, missing_events, backfilled=False)
-
- async def _get_state_after_missing_prev_event(
- self,
- destination: str,
- room_id: str,
- event_id: str,
- ) -> List[EventBase]:
- """Requests all of the room state at a given event from a remote homeserver.
-
- Args:
- destination: The remote homeserver to query for the state.
- room_id: The id of the room we're interested in.
- event_id: The id of the event we want the state at.
-
- Returns:
- A list of events in the state, including the event itself
- """
- (
- state_event_ids,
- auth_event_ids,
- ) = await self.federation_client.get_room_state_ids(
- destination, room_id, event_id=event_id
- )
-
- logger.debug(
- "state_ids returned %i state events, %i auth events",
- len(state_event_ids),
- len(auth_event_ids),
- )
-
- # start by just trying to fetch the events from the store
- desired_events = set(state_event_ids)
- desired_events.add(event_id)
- logger.debug("Fetching %i events from cache/store", len(desired_events))
- fetched_events = await self.store.get_events(
- desired_events, allow_rejected=True
- )
-
- missing_desired_events = desired_events - fetched_events.keys()
- logger.debug(
- "We are missing %i events (got %i)",
- len(missing_desired_events),
- len(fetched_events),
- )
-
- # We probably won't need most of the auth events, so let's just check which
- # we have for now, rather than thrashing the event cache with them all
- # unnecessarily.
-
- # TODO: we probably won't actually need all of the auth events, since we
- # already have a bunch of the state events. It would be nice if the
- # federation api gave us a way of finding out which we actually need.
-
- missing_auth_events = set(auth_event_ids) - fetched_events.keys()
- missing_auth_events.difference_update(
- await self.store.have_seen_events(room_id, missing_auth_events)
- )
- logger.debug("We are also missing %i auth events", len(missing_auth_events))
-
- missing_events = missing_desired_events | missing_auth_events
- logger.debug("Fetching %i events from remote", len(missing_events))
- await self._get_events_and_persist(
- destination=destination, room_id=room_id, events=missing_events
- )
-
- # we need to make sure we re-load from the database to get the rejected
- # state correct.
- fetched_events.update(
- await self.store.get_events(missing_desired_events, allow_rejected=True)
- )
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
-
- bad_events = [
- (event_id, event.room_id)
- for event_id, event in fetched_events.items()
- if event.room_id != room_id
- ]
-
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
-
- del fetched_events[bad_event_id]
-
- # if we couldn't get the prev event in question, that's a problem.
- remote_event = fetched_events.get(event_id)
- if not remote_event:
- raise Exception("Unable to get missing prev_event %s" % (event_id,))
-
- # missing state at that event is a warning, not a blocker
- # XXX: this doesn't sound right? it means that we'll end up with incomplete
- # state.
- failed_to_fetch = desired_events - fetched_events.keys()
- if failed_to_fetch:
- logger.warning(
- "Failed to fetch missing state events for %s %s",
- event_id,
- failed_to_fetch,
- )
-
- remote_state = [
- fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
- ]
-
- if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
-
- return remote_state
-
- async def _process_received_pdu(
- self,
- origin: str,
- event: EventBase,
- state: Optional[Iterable[EventBase]],
- backfilled: bool = False,
- ) -> None:
- """Called when we have a new pdu. We need to do auth checks and put it
- through the StateHandler.
-
- Args:
- origin: server sending the event
-
- event: event to be persisted
-
- state: Normally None, but if we are handling a gap in the graph
- (ie, we are missing one or more prev_events), the resolved state at the
- event
-
- backfilled: True if this is part of a historical batch of events (inhibits
- notification to clients, and validation of device keys.)
- """
- logger.debug("Processing event: %s", event)
-
- try:
- context = await self.state_handler.compute_event_context(
- event, old_state=state
- )
- await self._auth_and_persist_event(
- origin, event, context, state=state, backfilled=backfilled
- )
- except AuthError as e:
- raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
-
- if backfilled:
- return
-
- # For encrypted messages we check that we know about the sending device,
- # if we don't then we mark the device cache for that user as stale.
- if event.type == EventTypes.Encrypted:
- device_id = event.content.get("device_id")
- sender_key = event.content.get("sender_key")
-
- cached_devices = await self.store.get_cached_devices_for_user(event.sender)
-
- resync = False # Whether we should resync device lists.
-
- device = None
- if device_id is not None:
- device = cached_devices.get(device_id)
- if device is None:
- logger.info(
- "Received event from remote device not in our cache: %s %s",
- event.sender,
- device_id,
- )
- resync = True
-
- # We also check if the `sender_key` matches what we expect.
- if sender_key is not None:
- # Figure out what sender key we're expecting. If we know the
- # device and recognize the algorithm then we can work out the
- # exact key to expect. Otherwise check it matches any key we
- # have for that device.
-
- current_keys: Container[str] = []
-
- if device:
- keys = device.get("keys", {}).get("keys", {})
-
- if (
- event.content.get("algorithm")
- == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2
- ):
- # For this algorithm we expect a curve25519 key.
- key_name = "curve25519:%s" % (device_id,)
- current_keys = [keys.get(key_name)]
- else:
- # We don't understand the algorithm, so we just
- # check it matches a key for the device.
- current_keys = keys.values()
- elif device_id:
- # We don't have any keys for the device ID.
- pass
- else:
- # The event didn't include a device ID, so we just look for
- # keys across all devices.
- current_keys = [
- key
- for device in cached_devices.values()
- for key in device.get("keys", {}).get("keys", {}).values()
- ]
-
- # We now check that the sender key matches (one of) the expected
- # keys.
- if sender_key not in current_keys:
- logger.info(
- "Received event from remote device with unexpected sender key: %s %s: %s",
- event.sender,
- device_id or "<no device_id>",
- sender_key,
- )
- resync = True
-
- if resync:
- run_as_background_process(
- "resync_device_due_to_pdu", self._resync_device, event.sender
- )
-
- await self._handle_marker_event(origin, event)
-
- async def _handle_marker_event(self, origin: str, marker_event: EventBase):
- """Handles backfilling the insertion event when we receive a marker
- event that points to one.
-
- Args:
- origin: Origin of the event; the server we will query for the insertion event
- marker_event: The event to process
- """
-
- if marker_event.type != EventTypes.MSC2716_MARKER:
- # Not a marker event
- return
-
- if marker_event.rejected_reason is not None:
- # Rejected event
- return
-
- # Skip processing a marker event if the room version doesn't
- # support it.
- room_version = await self.store.get_room_version(marker_event.room_id)
- if not room_version.msc2716_historical:
- return
-
- logger.debug("_handle_marker_event: received %s", marker_event)
-
- insertion_event_id = marker_event.content.get(
- EventContentFields.MSC2716_MARKER_INSERTION
- )
-
- if insertion_event_id is None:
- # Nothing to retrieve then (invalid marker)
- return
-
- logger.debug(
- "_handle_marker_event: backfilling insertion event %s", insertion_event_id
- )
-
- await self._get_events_and_persist(
- origin,
- marker_event.room_id,
- [insertion_event_id],
- )
-
- insertion_event = await self.store.get_event(
- insertion_event_id, allow_none=True
- )
- if insertion_event is None:
- logger.warning(
- "_handle_marker_event: server %s didn't return insertion event %s for marker %s",
- origin,
- insertion_event_id,
- marker_event.event_id,
- )
- return
-
- logger.debug(
- "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s",
- insertion_event,
- marker_event,
- )
-
- await self.store.insert_insertion_extremity(
- insertion_event_id, marker_event.room_id
- )
-
- logger.debug(
- "_handle_marker_event: insertion extremity added for %s from marker event %s",
- insertion_event,
- marker_event,
- )
-
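For reference, the marker event consumed by `_handle_marker_event` carries the insertion event's ID in its content under the `EventContentFields.MSC2716_MARKER_INSERTION` key; at the time this was the unstable-prefixed field sketched below, though MSC2716 has since been revised, so treat the literal string as an assumption:

    marker_content = {
        "org.matrix.msc2716.marker.insertion": "$insertion_event_id",
    }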
- async def _resync_device(self, sender: str) -> None:
- """We have detected that the device list for the given user may be out
- of sync, so we try and resync them.
- """
-
- try:
- await self.store.mark_remote_user_device_cache_as_stale(sender)
-
- # Immediately attempt a resync in the background
- if self.config.worker_app:
- await self._user_device_resync(user_id=sender)
- else:
- await self._device_list_updater.user_device_resync(sender)
- except Exception:
- logger.exception("Failed to resync device for %s", sender)
-
- @log_function
- async def backfill(
- self, dest: str, room_id: str, limit: int, extremities: List[str]
- ) -> None:
- """Trigger a backfill request to `dest` for the given `room_id`
-
- This will attempt to get more events from the remote. If the other side
- has no new events to offer, this will return an empty list.
-
- As the events are received, we check their signatures, and also do some
- sanity-checking on them. If any of the backfilled events are invalid,
- this method throws a SynapseError.
-
- We might also raise an InvalidResponseError if the response from the remote
- server is just bogus.
-
- TODO: make this more useful to distinguish failures of the remote
- server from invalid events (there is probably no point in trying to
- re-fetch invalid events from every other HS in the room.)
- """
- if dest == self.server_name:
- raise SynapseError(400, "Can't backfill from self.")
-
- events = await self.federation_client.backfill(
- dest, room_id, limit=limit, extremities=extremities
- )
-
- if not events:
- return
-
- # if there are any events in the wrong room, the remote server is buggy and
- # should not be trusted.
- for ev in events:
- if ev.room_id != room_id:
- raise InvalidResponseError(
- f"Remote server {dest} returned event {ev.event_id} which is in "
- f"room {ev.room_id}, when we were backfilling in {room_id}"
- )
-
- await self._process_pulled_events(dest, events, backfilled=True)
-
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int
) -> bool:
@@ -995,7 +308,7 @@ class FederationHandler(BaseHandler):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- await self.backfill(
+ await self._federation_event_handler.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -1086,315 +399,6 @@ class FederationHandler(BaseHandler):
return False
- async def _get_events_and_persist(
- self, destination: str, room_id: str, events: Iterable[str]
- ) -> None:
- """Fetch the given events from a server, and persist them as outliers.
-
- This function *does not* recursively get missing auth events of the
- newly fetched events. Callers must include in the `events` argument
- any missing events from the auth chain.
-
- Logs a warning if we can't find the given event.
- """
-
- room_version = await self.store.get_room_version(room_id)
-
- event_map: Dict[str, EventBase] = {}
-
- async def get_event(event_id: str):
- with nested_logging_context(event_id):
- try:
- event = await self.federation_client.get_pdu(
- [destination],
- event_id,
- room_version,
- outlier=True,
- )
- if event is None:
- logger.warning(
- "Server %s didn't return event %s",
- destination,
- event_id,
- )
- return
-
- event_map[event.event_id] = event
-
- except Exception as e:
- logger.warning(
- "Error fetching missing state/auth event %s: %s %s",
- event_id,
- type(e),
- e,
- )
-
- await concurrently_execute(get_event, events, 5)
-
- # Make a map of auth events for each event. We do this after fetching
- # all the events as some of the events' auth events will be in the list
- # of requested events.
-
- auth_events = [
- aid
- for event in event_map.values()
- for aid in event.auth_event_ids()
- if aid not in event_map
- ]
- persisted_events = await self.store.get_events(
- auth_events,
- allow_rejected=True,
- )
-
- event_infos = []
- for event in event_map.values():
- auth = {}
- for auth_event_id in event.auth_event_ids():
- ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
- if ae:
- auth[(ae.type, ae.state_key)] = ae
- else:
- logger.info("Missing auth event %s", auth_event_id)
-
- event_infos.append(_NewEventInfo(event, auth))
-
- if event_infos:
- await self._auth_and_persist_events(
- destination,
- room_id,
- event_infos,
- )
-
- async def _process_pulled_events(
- self, origin: str, events: Iterable[EventBase], backfilled: bool
- ) -> None:
- """Process a batch of events we have pulled from a remote server
-
- Pulls in any events required to auth the events, persists the received events,
- and notifies clients, if appropriate.
-
- Assumes the events have already had their signatures and hashes checked.
-
- Args:
- origin: The server we received these events from
- events: The received events.
- backfilled: True if this is part of a historical batch of events (inhibits
- notification to clients, and validation of device keys.)
- """
-
- # We want to sort these by depth so we process them and
- # tell clients about them in order.
- sorted_events = sorted(events, key=lambda x: x.depth)
-
- for ev in sorted_events:
- with nested_logging_context(ev.event_id):
- await self._process_pulled_event(origin, ev, backfilled=backfilled)
-
- async def _process_pulled_event(
- self, origin: str, event: EventBase, backfilled: bool
- ) -> None:
- """Process a single event that we have pulled from a remote server
-
- Pulls in any events required to auth the event, persists the received event,
- and notifies clients, if appropriate.
-
- Assumes the event has already had its signatures and hashes checked.
-
- This is broadly equivalent to on_receive_pdu, but applies somewhat different
- logic in the case that we are missing prev_events (in particular, it just
- requests the state at that point, rather than triggering a get_missing_events) -
- so is appropriate when we have pulled the event from a remote server, rather
- than having it pushed to us.
-
- Args:
- origin: The server we received this event from
- event: The received event
- backfilled: True if this is part of a historical batch of events (inhibits
- notification to clients, and validation of device keys.)
- """
- logger.info("Processing pulled event %s", event)
-
- # these should not be outliers.
- assert not event.internal_metadata.is_outlier()
-
- event_id = event.event_id
-
- existing = await self.store.get_event(
- event_id, allow_none=True, allow_rejected=True
- )
- if existing:
- if not existing.internal_metadata.is_outlier():
- logger.info(
- "Ignoring received event %s which we have already seen",
- event_id,
- )
- return
- logger.info("De-outliering event %s", event_id)
-
- try:
- self._sanity_check_event(event)
- except SynapseError as err:
- logger.warning("Event %s failed sanity check: %s", event_id, err)
- return
-
- try:
- state = await self._resolve_state_at_missing_prevs(origin, event)
- await self._process_received_pdu(
- origin, event, state=state, backfilled=backfilled
- )
- except FederationError as e:
- if e.code == 403:
- logger.warning("Pulled event %s failed history check.", event_id)
- else:
- raise
-
- async def _resolve_state_at_missing_prevs(
- self, dest: str, event: EventBase
- ) -> Optional[Iterable[EventBase]]:
- """Calculate the state at an event with missing prev_events.
-
- This is used when we have pulled a batch of events from a remote server, and
- still don't have all the prev_events.
-
- If we already have all the prev_events for `event`, this method does nothing.
-
- Otherwise, the missing prevs become new backwards extremities, and we fall back
- to asking the remote server for the state after each missing `prev_event`,
- and resolving across them.
-
- That's ok provided we then resolve the state against other bits of the DAG
- before using it - in other words, that the received event `event` is not going
- to become the only forwards_extremity in the room (which will ensure that you
- can't just take over a room by sending an event, withholding its prev_events,
- and declaring yourself to be an admin in the subsequent state request).
-
- In other words: we should only call this method if `event` has been *pulled*
- as part of a batch of missing prev events, or similar.
-
- Args:
- dest: the remote server to ask for state at the missing prevs. Typically,
- this will be the server we got `event` from.
- event: an event to check for missing prevs.
-
- Returns:
- if we already had all the prev events, `None`. Otherwise, returns a list of
- the events in the state at `event`.
- """
- room_id = event.room_id
- event_id = event.event_id
-
- prevs = set(event.prev_event_ids())
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
-
- if not missing_prevs:
- return None
-
- logger.info(
- "Event %s is missing prev_events %s: calculating state for a "
- "backwards extremity",
- event_id,
- shortstr(missing_prevs),
- )
- # Calculate the state after each of the previous events, and
- # resolve them to find the correct state at the current event.
- event_map = {event_id: event}
- try:
- # Get the state of the events we know about
- ours = await self.state_store.get_state_groups_ids(room_id, seen)
-
- # state_maps is a list of mappings from (type, state_key) to event_id
- state_maps: List[StateMap[str]] = list(ours.values())
-
- # we don't need this any more, let's delete it.
- del ours
-
- # Ask the remote server for the states we don't
- # know about
- for p in missing_prevs:
- logger.info("Requesting state after missing prev_event %s", p)
-
- with nested_logging_context(p):
- # note that if any of the missing prevs share missing state or
- # auth events, the requests to fetch those events are deduped
- # by the get_pdu_cache in federation_client.
- remote_state = await self._get_state_after_missing_prev_event(
- dest, room_id, p
- )
-
- remote_state_map = {
- (x.type, x.state_key): x.event_id for x in remote_state
- }
- state_maps.append(remote_state_map)
-
- for x in remote_state:
- event_map[x.event_id] = x
-
- room_version = await self.store.get_room_version_id(room_id)
- state_map = await self._state_resolution_handler.resolve_events_with_store(
- room_id,
- room_version,
- state_maps,
- event_map,
- state_res_store=StateResolutionStore(self.store),
- )
-
- # We need to give _process_received_pdu the actual state events
- # rather than event ids, so generate that now.
-
- # First though we need to fetch all the events that are in
- # state_map, so we can build up the state below.
- evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
- )
- event_map.update(evs)
-
- state = [event_map[e] for e in state_map.values()]
- except Exception:
- logger.warning(
- "Error attempting to resolve state at missing prev_events",
- exc_info=True,
- )
- raise FederationError(
- "ERROR",
- 403,
- "We can't get valid state history.",
- affected=event_id,
- )
- return state
-
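The `state_maps` fed into `resolve_events_with_store` above hold one mapping per fork of the DAG, each from (type, state_key) to an event ID; state resolution then picks a single winner per key. An illustrative shape, with hypothetical IDs:

    state_maps = [
        # state after a prev_event we already had
        {("m.room.create", ""): "$create", ("m.room.member", "@a:hs"): "$join"},
        # state after a missing prev_event, as reported by the remote
        {("m.room.create", ""): "$create", ("m.room.member", "@a:hs"): "$ban"},
    ]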
- def _sanity_check_event(self, ev: EventBase) -> None:
- """
- Do some early sanity checks of a received event
-
- In particular, checks it doesn't have an excessive number of
- prev_events or auth_events, which could cause a huge state resolution
- or cascade of event fetches.
-
- Args:
- ev: event to be checked
-
- Raises:
- SynapseError if the event does not pass muster
- """
- if len(ev.prev_event_ids()) > 20:
- logger.warning(
- "Rejecting event %s which has %i prev_events",
- ev.event_id,
- len(ev.prev_event_ids()),
- )
- raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events")
-
- if len(ev.auth_event_ids()) > 10:
- logger.warning(
- "Rejecting event %s which has %i auth_events",
- ev.event_id,
- len(ev.auth_event_ids()),
- )
- raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
-
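The limits in `_sanity_check_event` (20 prev_events, 10 auth_events) are defensive caps; extracted as a standalone sketch:

    MAX_PREV_EVENTS = 20
    MAX_AUTH_EVENTS = 10

    def sanity_check(prev_event_ids: list, auth_event_ids: list) -> None:
        # reject events that could trigger huge state resolutions or fetch cascades
        if len(prev_event_ids) > MAX_PREV_EVENTS:
            raise ValueError("Too many prev_events")
        if len(auth_event_ids) > MAX_AUTH_EVENTS:
            raise ValueError("Too many auth_events")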
async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing.
@@ -1460,9 +464,9 @@ class FederationHandler(BaseHandler):
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
- assert room_id not in self.room_queues
+ assert room_id not in self._federation_event_handler.room_queues
- self.room_queues[room_id] = []
+ self._federation_event_handler.room_queues[room_id] = []
await self._clean_room_for_join(room_id)
@@ -1536,8 +540,8 @@ class FederationHandler(BaseHandler):
logger.debug("Finished joining %s to %s", joinee, room_id)
return event.event_id, max_stream_id
finally:
- room_queue = self.room_queues[room_id]
- del self.room_queues[room_id]
+ room_queue = self._federation_event_handler.room_queues[room_id]
+ del self._federation_event_handler.room_queues[room_id]
# we don't need to wait for the queued events to be processed -
# it's just a best-effort thing at this point. We do want to do
@@ -1613,7 +617,7 @@ class FederationHandler(BaseHandler):
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
context = await self.state_handler.compute_event_context(event)
- stream_id = await self.persist_events_and_notify(
+ stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
return event.event_id, stream_id
@@ -1633,7 +637,7 @@ class FederationHandler(BaseHandler):
p,
)
with nested_logging_context(p.event_id):
- await self.on_receive_pdu(origin, p)
+ await self._federation_event_handler.on_receive_pdu(origin, p)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1726,7 +730,7 @@ class FederationHandler(BaseHandler):
raise
# Ensure the user can even join the room.
- await self._check_join_restrictions(context, event)
+ await self._federation_event_handler.check_join_restrictions(context, event)
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
@@ -1803,7 +807,9 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
- await self.persist_events_and_notify(event.room_id, [(event, context)])
+ await self._federation_event_handler.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
return event
@@ -1830,7 +836,7 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
- stream_id = await self.persist_events_and_notify(
+ stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -1973,116 +979,6 @@ class FederationHandler(BaseHandler):
return event
- @log_function
- async def on_send_membership_event(
- self, origin: str, event: EventBase
- ) -> Tuple[EventBase, EventContext]:
- """
- We have received a join/leave/knock event for a room via send_join/leave/knock.
-
- Verify that event and send it into the room on the remote homeserver's behalf.
-
- This is quite similar to on_receive_pdu, with the following principal
- differences:
- * only membership events are permitted (and only events with
- sender==state_key -- ie, no kicks or bans)
- * *We* send out the event on behalf of the remote server.
- * We enforce the membership restrictions of restricted rooms.
- * Rejected events result in an exception rather than being stored.
-
- There are also other differences, however it is not clear if these are by
- design or omission. In particular, we do not attempt to backfill any missing
- prev_events.
-
- Args:
- origin: The homeserver of the remote (joining/invited/knocking) user.
- event: The member event that has been signed by the remote homeserver.
-
- Returns:
- The event and context of the event after inserting it into the room graph.
-
- Raises:
- SynapseError if the event is not accepted into the room
- """
- logger.debug(
- "on_send_membership_event: Got event: %s, signatures: %s",
- event.event_id,
- event.signatures,
- )
-
- if get_domain_from_id(event.sender) != origin:
- logger.info(
- "Got send_membership request for user %r from different origin %s",
- event.sender,
- origin,
- )
- raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
-
- if event.sender != event.state_key:
- raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON)
-
- assert not event.internal_metadata.outlier
-
- # Send this event on behalf of the other server.
- #
- # The remote server isn't a full participant in the room at this point, so
- # may not have an up-to-date list of the other homeservers participating in
- # the room, so we send it on their behalf.
- event.internal_metadata.send_on_behalf_of = origin
-
- context = await self.state_handler.compute_event_context(event)
- context = await self._check_event_auth(origin, event, context)
- if context.rejected:
- raise SynapseError(
- 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
- )
-
- # for joins, we need to check the restrictions of restricted rooms
- if event.membership == Membership.JOIN:
- await self._check_join_restrictions(context, event)
-
- # for knock events, we run the third-party event rules. It's not entirely clear
- # why we don't do this for other sorts of membership events.
- if event.membership == Membership.KNOCK:
- event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
- event, context
- )
- if not event_allowed:
- logger.info("Sending of knock %s forbidden by third-party rules", event)
- raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN
- )
-
- # all looks good, we can persist the event.
- await self._run_push_actions_and_persist_event(event, context)
- return event, context
-
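The first two checks in `on_send_membership_event` reduce to a simple predicate over the event's sender; in miniature (a sketch using `get_domain_from_id`, which remains imported in this file):

    def may_send_membership(origin: str, sender: str, state_key: str) -> bool:
        # send_join/leave/knock may only carry an event about the sending
        # server's own user, acting on themselves (no kicks or bans)
        return get_domain_from_id(sender) == origin and sender == state_key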
- async def _check_join_restrictions(
- self, context: EventContext, event: EventBase
- ) -> None:
- """Check that restrictions in restricted join rules are matched
-
- Called when we receive a join event via send_join.
-
- Raises an auth error if the restrictions are not matched.
- """
- prev_state_ids = await context.get_prev_state_ids()
-
- # Check if the user is already in the room or invited to the room.
- user_id = event.state_key
- prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- prev_member_event = None
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
-
- # Check if the member should be allowed access via membership in a space.
- await self._event_auth_handler.check_restricted_join_rules(
- prev_state_ids,
- event.room_version,
- user_id,
- prev_member_event,
- )
-
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
@@ -2183,126 +1079,6 @@ class FederationHandler(BaseHandler):
else:
return None
- async def get_min_depth_for_context(self, context: str) -> int:
- return await self.store.get_min_depth(context)
-
- async def _auth_and_persist_event(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
- backfilled: bool = False,
- ) -> None:
- """
- Process an event by performing auth checks and then persisting to the database.
-
- Args:
- origin: The host the event originates from.
- event: The event itself.
- context:
- The event context.
-
- state:
- The state events used to check the event for soft-fail. If this is
- not provided the current state events will be used.
-
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
-
- Only populated where we may not already have persisted these events -
- for example, when populating outliers.
-
- backfilled: True if the event was backfilled.
- """
- context = await self._check_event_auth(
- origin,
- event,
- context,
- state=state,
- claimed_auth_event_map=claimed_auth_event_map,
- backfilled=backfilled,
- )
-
- await self._run_push_actions_and_persist_event(event, context, backfilled)
-
- async def _run_push_actions_and_persist_event(
- self, event: EventBase, context: EventContext, backfilled: bool = False
- ):
- """Run the push actions for a received event, and persist it.
-
- Args:
- event: The event itself.
- context: The event context.
- backfilled: True if the event was backfilled.
- """
- try:
- if (
- not event.internal_metadata.is_outlier()
- and not backfilled
- and not context.rejected
- and (await self.store.get_min_depth(event.room_id)) <= event.depth
- ):
- await self.action_generator.handle_push_actions_for_event(
- event, context
- )
-
- await self.persist_events_and_notify(
- event.room_id, [(event, context)], backfilled=backfilled
- )
- except Exception:
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
- raise
-
- async def _auth_and_persist_events(
- self,
- origin: str,
- room_id: str,
- event_infos: Collection[_NewEventInfo],
- ) -> None:
- """Creates the appropriate contexts and persists events. The events
- should not depend on one another, e.g. this should be used to persist
- a bunch of outliers, but not a chunk of individual events that depend
- on each other for state calculations.
-
- Notifies about the events where appropriate.
- """
-
- if not event_infos:
- return
-
- async def prep(ev_info: _NewEventInfo):
- event = ev_info.event
- with nested_logging_context(suffix=event.event_id):
- res = await self.state_handler.compute_event_context(event)
- res = await self._check_event_auth(
- origin,
- event,
- res,
- claimed_auth_event_map=ev_info.claimed_auth_event_map,
- )
- return res
-
- contexts = await make_deferred_yieldable(
- defer.gatherResults(
- [run_in_background(prep, ev_info) for ev_info in event_infos],
- consumeErrors=True,
- )
- )
-
- await self.persist_events_and_notify(
- room_id,
- [
- (ev_info.event, context)
- for ev_info, context in zip(event_infos, contexts)
- ],
- )
-
async def _persist_auth_tree(
self,
origin: str,
@@ -2400,7 +1176,7 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
if auth_events or state:
- await self.persist_events_and_notify(
+ await self._federation_event_handler.persist_events_and_notify(
room_id,
[
(e, events_to_context[e.event_id])
@@ -2412,108 +1188,10 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
- return await self.persist_events_and_notify(
+ return await self._federation_event_handler.persist_events_and_notify(
room_id, [(event, new_event_context)]
)
- async def _check_for_soft_fail(
- self,
- event: EventBase,
- state: Optional[Iterable[EventBase]],
- backfilled: bool,
- origin: str,
- ) -> None:
- """Checks if we should soft fail the event; if so, marks the event as
- such.
-
- Args:
- event
- state: The state at the event if we don't have all the event's prev events
- backfilled: Whether the event is from backfill
- origin: The host the event originates from.
- """
- # For new (non-backfilled and non-outlier) events we check if the event
- # passes auth based on the current state. If it doesn't then we
- # "soft-fail" the event.
- if backfilled or event.internal_metadata.is_outlier():
- return
-
- extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
- extrem_ids = set(extrem_ids_list)
- prev_event_ids = set(event.prev_event_ids())
-
- if extrem_ids == prev_event_ids:
- # If they're the same then the current state is the same as the
- # state at the event, so no point rechecking auth for soft fail.
- return
-
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- # Calculate the "current state".
- if state is not None:
- # If we're explicitly given the state then we won't have all the
- # prev events, and so we have a gap in the graph. In this case
- # we want to be a little careful as we might have been down for
- # a while and have an incorrect view of the current state,
- # however we still want to do checks as gaps are easy to
- # maliciously manufacture.
- #
- # So we use a "current state" that is actually a state
- # resolution across the current forward extremities and the
- # given state at the event. This should correctly handle cases
- # like bans, especially with state res v2.
-
- state_sets_d = await self.state_store.get_state_groups(
- event.room_id, extrem_ids
- )
- state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
- state_sets.append(state)
- current_states = await self.state_handler.resolve_events(
- room_version, state_sets, event
- )
- current_state_ids: StateMap[str] = {
- k: e.event_id for k, e in current_states.items()
- }
- else:
- current_state_ids = await self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
- )
-
- logger.debug(
- "Doing soft-fail check for %s: state %s",
- event.event_id,
- current_state_ids,
- )
-
- # Now check if event pass auth against said current state
- auth_types = auth_types_for_event(room_version_obj, event)
- current_state_ids_list = [
- e for k, e in current_state_ids.items() if k in auth_types
- ]
-
- auth_events_map = await self.store.get_events(current_state_ids_list)
- current_auth_events = {
- (e.type, e.state_key): e for e in auth_events_map.values()
- }
-
- try:
- event_auth.check(room_version_obj, event, auth_events=current_auth_events)
- except AuthError as e:
- logger.warning(
- "Soft-failing %r (from %s) because %s",
- event,
- origin,
- e,
- extra={
- "room_id": event.room_id,
- "mxid": event.sender,
- "hs": origin,
- },
- )
- soft_failed_event_counter.inc()
- event.internal_metadata.soft_failed = True
-
async def on_get_missing_events(
self,
origin: str,
@@ -2542,334 +1220,6 @@ class FederationHandler(BaseHandler):
return missing_events
- async def _check_event_auth(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
- backfilled: bool = False,
- ) -> EventContext:
- """
- Checks whether an event should be rejected (for failing auth checks).
-
- Args:
- origin: The host the event originates from.
- event: The event itself.
- context:
- The event context.
-
- state:
- The state events used to check the event for soft-fail. If this is
- not provided the current state events will be used.
-
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
-
- Only populated where we may not already have persisted these events -
- for example, when populating outliers, or the state for a backwards
- extremity.
-
- backfilled: True if the event was backfilled.
-
- Returns:
- The updated context object.
- """
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- if claimed_auth_event_map:
- # if we have a copy of the auth events from the event, use that as the
- # basis for auth.
- auth_events = claimed_auth_event_map
- else:
- # otherwise, we calculate what the auth events *should* be, and use that
- prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self._event_auth_handler.compute_auth_events(
- event, prev_state_ids, for_verification=True
- )
- auth_events_x = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
-
- try:
- (
- context,
- auth_events_for_auth,
- ) = await self._update_auth_events_and_context_for_auth(
- origin, event, context, auth_events
- )
- except Exception:
- # We don't really mind if the above fails, so lets not fail
- # processing if it does. However, it really shouldn't fail so
- # let's still log as an exception since we'll still want to fix
- # any bugs.
- logger.exception(
- "Failed to double check auth events for %s with remote. "
- "Ignoring failure and continuing processing of event.",
- event.event_id,
- )
- auth_events_for_auth = auth_events
-
- try:
- event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
- except AuthError as e:
- logger.warning("Failed auth resolution for %r because %s", event, e)
- context.rejected = RejectedReason.AUTH_ERROR
-
- if not context.rejected:
- await self._check_for_soft_fail(event, state, backfilled, origin=origin)
-
- if event.type == EventTypes.GuestAccess and not context.rejected:
- await self.maybe_kick_guest_users(event)
-
- # If we are going to send this event over federation we precalculate
- # the joined hosts.
- if event.internal_metadata.get_send_on_behalf_of():
- await self.event_creation_handler.cache_joined_hosts_for_event(
- event, context
- )
-
- return context
-
- async def _update_auth_events_and_context_for_auth(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- input_auth_events: StateMap[EventBase],
- ) -> Tuple[EventContext, StateMap[EventBase]]:
- """Helper for _check_event_auth. See there for docs.
-
- Checks whether a given event has the expected auth events. If it
- doesn't then we talk to the remote server to compare state to see if
- we can come to a consensus (e.g. if one server missed some valid
- state).
-
- This attempts to resolve any potential divergence of state between
- servers, but is not essential and so failures should not block further
- processing of the event.
-
- Args:
- origin:
- event:
- context:
-
- input_auth_events:
- Map from (event_type, state_key) to event
-
- Normally, our calculated auth_events based on the state of the room
- at the event's position in the DAG, though occasionally (eg if the
- event is an outlier), may be the auth events claimed by the remote
- server.
-
- Returns:
- updated context, updated auth event map
- """
- # take a copy of input_auth_events before we modify it.
- auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
-
- event_auth_events = set(event.auth_event_ids())
-
- # missing_auth is the set of the event's auth_events which we don't yet have
- # in auth_events.
- missing_auth = event_auth_events.difference(
- e.event_id for e in auth_events.values()
- )
-
- # if we have missing events, we need to fetch those events from somewhere.
- #
- # we start by checking if they are in the store, and then try calling /event_auth/.
- if missing_auth:
- have_events = await self.store.have_seen_events(event.room_id, missing_auth)
- logger.debug("Events %s are in the store", have_events)
- missing_auth.difference_update(have_events)
-
- if missing_auth:
- # If we don't have all the auth events, we need to get them.
- logger.info("auth_events contains unknown events: %s", missing_auth)
- try:
- try:
- remote_auth_chain = await self.federation_client.get_event_auth(
- origin, event.room_id, event.event_id
- )
- except RequestSendFailed as e1:
- # The other side isn't around or doesn't implement the
- # endpoint, so lets just bail out.
- logger.info("Failed to get event auth from remote: %s", e1)
- return context, auth_events
-
- seen_remotes = await self.store.have_seen_events(
- event.room_id, [e.event_id for e in remote_auth_chain]
- )
-
- for e in remote_auth_chain:
- if e.event_id in seen_remotes:
- continue
-
- if e.event_id == event.event_id:
- continue
-
- try:
- auth_ids = e.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in remote_auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- e.internal_metadata.outlier = True
-
- logger.debug(
- "_check_event_auth %s missing_auth: %s",
- event.event_id,
- e.event_id,
- )
- missing_auth_event_context = (
- await self.state_handler.compute_event_context(e)
- )
- await self._auth_and_persist_event(
- origin,
- e,
- missing_auth_event_context,
- claimed_auth_event_map=auth,
- )
-
- if e.event_id in event_auth_events:
- auth_events[(e.type, e.state_key)] = e
- except AuthError:
- pass
-
- except Exception:
- logger.exception("Failed to get auth chain")
-
- if event.internal_metadata.is_outlier():
- # XXX: given that, for an outlier, we'll be working with the
- # event's *claimed* auth events rather than those we calculated:
- # (a) is there any point in this test, since different_auth below will
- # obviously be empty
- # (b) alternatively, why don't we do it earlier?
- logger.info("Skipping auth_event fetch for outlier")
- return context, auth_events
-
- different_auth = event_auth_events.difference(
- e.event_id for e in auth_events.values()
- )
-
- if not different_auth:
- return context, auth_events
-
- logger.info(
- "auth_events refers to events which are not in our calculated auth "
- "chain: %s",
- different_auth,
- )
-
- # XXX: currently this checks for redactions but I'm not convinced that is
- # necessary?
- different_events = await self.store.get_events_as_list(different_auth)
-
- for d in different_events:
- if d.room_id != event.room_id:
- logger.warning(
- "Event %s refers to auth_event %s which is in a different room",
- event.event_id,
- d.event_id,
- )
-
- # don't attempt to resolve the claimed auth events against our own
- # in this case: just use our own auth events.
- #
- # XXX: should we reject the event in this case? It feels like we should,
- # but then shouldn't we also do so if we've failed to fetch any of the
- # auth events?
- return context, auth_events
-
- # now we state-resolve between our own idea of the auth events, and the remote's
- # idea of them.
-
- local_state = auth_events.values()
- remote_auth_events = dict(auth_events)
- remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
- remote_state = remote_auth_events.values()
-
- room_version = await self.store.get_room_version_id(event.room_id)
- new_state = await self.state_handler.resolve_events(
- room_version, (local_state, remote_state), event
- )
-
- logger.info(
- "After state res: updating auth_events with new state %s",
- {
- (d.type, d.state_key): d.event_id
- for d in new_state.values()
- if auth_events.get((d.type, d.state_key)) != d
- },
- )
-
- auth_events.update(new_state)
-
- context = await self._update_context_for_auth_events(
- event, context, auth_events
- )
-
- return context, auth_events
-
- async def _update_context_for_auth_events(
- self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
- ) -> EventContext:
- """Update the state_ids in an event context after auth event resolution,
- storing the changes as a new state group.
-
- Args:
- event: The event we're handling the context for
-
- context: initial event context
-
- auth_events: Events to update in the event context.
-
- Returns:
- new event context
- """
- # exclude the state key of the new event from the current_state in the context.
- if event.is_state():
- event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
- else:
- event_key = None
- state_updates = {
- k: a.event_id for k, a in auth_events.items() if k != event_key
- }
-
- current_state_ids = await context.get_current_state_ids()
- current_state_ids = dict(current_state_ids) # type: ignore
-
- current_state_ids.update(state_updates)
-
- prev_state_ids = await context.get_prev_state_ids()
- prev_state_ids = dict(prev_state_ids)
-
- prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
- # create a new state group as a delta from the existing one.
- prev_group = context.state_group
- state_group = await self.state_store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=prev_group,
- delta_ids=state_updates,
- current_state_ids=current_state_ids,
- )
-
- return EventContext.with_state(
- state_group=state_group,
- state_group_before_event=context.state_group_before_event,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
- prev_group=prev_group,
- delta_ids=state_updates,
- )
-
async def construct_auth_difference(
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
) -> Dict:
@@ -3256,99 +1606,6 @@ class FederationHandler(BaseHandler):
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")
- async def persist_events_and_notify(
- self,
- room_id: str,
- event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
- backfilled: bool = False,
- ) -> int:
- """Persists events and tells the notifier/pushers about them, if
- necessary.
-
- Args:
- room_id: The room ID of events being persisted.
- event_and_contexts: Sequence of events with their associated
- context that should be persisted. All events must belong to
- the same room.
- backfilled: Whether these events are a result of
- backfilling or not
-
- Returns:
- The stream ID after which all events have been persisted.
- """
- if not event_and_contexts:
- return self.store.get_current_events_token()
-
- instance = self.config.worker.events_shard_config.get_instance(room_id)
- if instance != self._instance_name:
- # Limit the number of events sent over replication. We choose 200
- # here as that is what we default to in `max_request_body_size(..)`
- for batch in batch_iter(event_and_contexts, 200):
- result = await self._send_events(
- instance_name=instance,
- store=self.store,
- room_id=room_id,
- event_and_contexts=batch,
- backfilled=backfilled,
- )
- return result["max_stream_id"]
- else:
- assert self.storage.persistence
-
- # Note that this returns the events that were persisted, which may not be
- # the same as were passed in if some were deduplicated due to transaction IDs.
- events, max_stream_token = await self.storage.persistence.persist_events(
- event_and_contexts, backfilled=backfilled
- )
-
- if self._ephemeral_messages_enabled:
- for event in events:
- # If there's an expiry timestamp on the event, schedule its expiry.
- self._message_handler.maybe_schedule_expiry(event)
-
- if not backfilled: # Never notify for backfilled events
- for event in events:
- await self._notify_persisted_event(event, max_stream_token)
-
- return max_stream_token.stream
-
- async def _notify_persisted_event(
- self, event: EventBase, max_stream_token: RoomStreamToken
- ) -> None:
- """Checks to see if notifier/pushers should be notified about the
- event or not.
-
- Args:
- event:
- max_stream_id: The max_stream_id returned by persist_events
- """
-
- extra_users = []
- if event.type == EventTypes.Member:
- target_user_id = event.state_key
-
- # We notify for memberships if its an invite for one of our
- # users
- if event.internal_metadata.is_outlier():
- if event.membership != Membership.INVITE:
- if not self.is_mine_id(target_user_id):
- return
-
- target_user = UserID.from_string(target_user_id)
- extra_users.append(target_user)
- elif event.internal_metadata.is_outlier():
- return
-
- # the event has been persisted so it should have a stream ordering.
- assert event.internal_metadata.stream_ordering
-
- event_pos = PersistedEventPosition(
- self._instance_name, event.internal_metadata.stream_ordering
- )
- self.notifier.on_new_room_event(
- event, event_pos, max_stream_token, extra_users=extra_users
- )
-
async def _clean_room_for_join(self, room_id: str) -> None:
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
new file mode 100644
index 0000000000..9f055f00cf
--- /dev/null
+++ b/synapse/handlers/federation_event.py
@@ -0,0 +1,1825 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from http import HTTPStatus
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Container,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
+
+import attr
+from prometheus_client import Counter
+
+from twisted.internet import defer
+
+from synapse import event_auth
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RejectedReason,
+ RoomEncryptionAlgorithms,
+)
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ FederationError,
+ HttpResponseException,
+ RequestSendFailed,
+ SynapseError,
+)
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.event_auth import auth_types_for_event
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.federation.federation_client import InvalidResponseError
+from synapse.handlers._base import BaseHandler
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
+from synapse.logging.utils import log_function
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
+from synapse.replication.http.federation import (
+ ReplicationFederationSendEventsRestServlet,
+)
+from synapse.state import StateResolutionStore
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ MutableStateMap,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
+from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util.iterutils import batch_iter
+from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import shortstr
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+soft_failed_event_counter = Counter(
+ "synapse_federation_soft_failed_events_total",
+ "Events received over federation that we marked as soft_failed",
+)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _NewEventInfo:
+ """Holds information about a received event, ready for passing to _auth_and_persist_events
+
+ Attributes:
+ event: the received event
+
+ claimed_auth_event_map: a map of (type, state_key) => event for the event's
+ claimed auth_events.
+
+ This can include events which have not yet been persisted, in the case that
+ we are backfilling a batch of events.
+
+ Note: may be incomplete if we were unable to find all of the claimed auth
+ events. Also, treat the contents with caution: the events might also have
+ been rejected, might not yet have been authorized themselves, or they might
+ be in the wrong room.
+
+ """
+
+ event: EventBase
+ claimed_auth_event_map: StateMap[EventBase]
+
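+# Illustrative only, not part of the change itself: the claimed auth event map
+# is keyed on (event type, state key). Assuming `create_ev` and `member_ev` are
+# EventBase instances we have already fetched, constructing one might look like:
+#
+#     claimed_auth = {
+#         ("m.room.create", ""): create_ev,
+#         ("m.room.member", "@alice:example.com"): member_ev,
+#     }
+#     info = _NewEventInfo(event=pulled_event, claimed_auth_event_map=claimed_auth)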
+
+class FederationEventHandler(BaseHandler):
+ """Handles events that originated from federation.
+
+ Responsible for handling incoming events and passing them on to the rest
+ of the homeserver (including auth and state conflict resolution)
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
+
+ self.state_handler = hs.get_state_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
+ self._message_handler = hs.get_message_handler()
+ self.action_generator = hs.get_action_generator()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
+ self.federation_client = hs.get_federation_client()
+ self.third_party_event_rules = hs.get_third_party_event_rules()
+
+ self.is_mine_id = hs.is_mine_id
+ self._instance_name = hs.get_instance_name()
+
+ self.config = hs.config
+ self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
+
+ self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
+ if hs.config.worker_app:
+ self._user_device_resync = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
+ )
+ else:
+ self._device_list_updater = hs.get_device_handler().device_list_updater
+
+ # When joining a room we need to queue any events for that room up.
+ # For each room, a list of (pdu, origin) tuples.
+ # TODO: replace this with something more elegant, probably based around the
+ # federation event staging area.
+ self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
+
+ self._room_pdu_linearizer = Linearizer("fed_room_pdu")
+
+ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
+ """Process a PDU received via a federation /send/ transaction
+
+ Args:
+ origin: server which initiated the /send/ transaction. Will
+ be used to fetch missing events or state.
+ pdu: received PDU
+ """
+
+ room_id = pdu.room_id
+ event_id = pdu.event_id
+
+ # We reprocess pdus when we have seen them only as outliers
+ existing = await self.store.get_event(
+ event_id, allow_none=True, allow_rejected=True
+ )
+
+ # FIXME: Currently we fetch an event again when we already have it
+ # if it has been marked as an outlier.
+ if existing:
+ if not existing.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received event %s which we have already seen", event_id
+ )
+ return
+ if pdu.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received outlier %s which we already have as an outlier",
+ event_id,
+ )
+ return
+ logger.info("De-outliering event %s", event_id)
+
+ # do some initial sanity-checking of the event. In particular, make
+ # sure it doesn't have hundreds of prev_events or auth_events, which
+ # could cause a huge state resolution or cascade of event fetches.
+ try:
+ self._sanity_check_event(pdu)
+ except SynapseError as err:
+ logger.warning("Received event failed sanity checks")
+ raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
+
+ # If we are currently in the process of joining this room, then we
+ # queue up events for later processing.
+ if room_id in self.room_queues:
+ logger.info(
+ "Queuing PDU from %s for now: join in progress",
+ origin,
+ )
+ self.room_queues[room_id].append((pdu, origin))
+ return
+
+ # If we're not in the room just ditch the event entirely. This is
+ # probably an old server that has come back and thinks we're still in
+ # the room (or we've been rejoined to the room by a state reset).
+ #
+ # Note that if we were never in the room then we would have already
+ # dropped the event, since we wouldn't know the room version.
+ is_in_room = await self._event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
+ if not is_in_room:
+ logger.info(
+ "Ignoring PDU from %s as we're not in the room",
+ origin,
+ )
+ return None
+
+ # Check that the event passes auth based on the state at the event. This is
+ # done for events that are to be added to the timeline (non-outliers).
+ #
+ # Get missing pdus if necessary:
+ # - Fetching any missing prev events to fill in gaps in the graph
+ # - Fetching state if we have a hole in the graph
+ if not pdu.internal_metadata.is_outlier():
+ prevs = set(pdu.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if missing_prevs:
+ # We only backfill backwards to the min depth.
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
+ logger.debug("min_depth: %d", min_depth)
+
+ if min_depth is not None and pdu.depth > min_depth:
+ # If we're missing stuff, ensure we only fetch stuff one
+ # at a time.
+ logger.info(
+ "Acquiring room lock to fetch %d missing prev_events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
+ logger.info(
+ "Acquired room lock to fetch %d missing prev_events",
+ len(missing_prevs),
+ )
+
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ ) from e
+
+ # Update the set of things we've seen after trying to
+ # fetch the missing stuff
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ logger.info("Found all missing prev_events")
+
+ if missing_prevs:
+ # since this event was pushed to us, it is possible for it to
+ # become the only forward-extremity in the room, and we would then
+ # trust its state to be the state for the whole room. This is very
+ # bad. Further, if the event was pushed to us, there is no excuse
+ # for us not to have all the prev_events. (XXX: apart from
+ # min_depth?)
+ #
+ # We therefore reject any such events.
+ logger.warning(
+ "Rejecting: failed to fetch %d prev events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ (
+ "Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ affected=pdu.event_id,
+ )
+
+ await self._process_received_pdu(origin, pdu, state=None)
+
+ @log_function
+ async def on_send_membership_event(
+ self, origin: str, event: EventBase
+ ) -> Tuple[EventBase, EventContext]:
+ """
+ We have received a join/leave/knock event for a room via send_join/leave/knock.
+
+ Verify that event and send it into the room on the remote homeserver's behalf.
+
+ This is quite similar to on_receive_pdu, with the following principal
+ differences:
+ * only membership events are permitted (and only events with
+ sender==state_key -- ie, no kicks or bans)
+ * *We* send out the event on behalf of the remote server.
+ * We enforce the membership restrictions of restricted rooms.
+ * Rejected events result in an exception rather than being stored.
+
+ There are also other differences, however it is not clear if these are by
+ design or omission. In particular, we do not attempt to backfill any missing
+ prev_events.
+
+ Args:
+ origin: The homeserver of the remote (joining/invited/knocking) user.
+ event: The member event that has been signed by the remote homeserver.
+
+ Returns:
+ The event and context of the event after inserting it into the room graph.
+
+ Raises:
+ SynapseError if the event is not accepted into the room
+ """
+ logger.debug(
+ "on_send_membership_event: Got event: %s, signatures: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ if get_domain_from_id(event.sender) != origin:
+ logger.info(
+ "Got send_membership request for user %r from different origin %s",
+ event.sender,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ if event.sender != event.state_key:
+ raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON)
+
+ assert not event.internal_metadata.outlier
+
+ # Send this event on behalf of the other server.
+ #
+ # The remote server isn't a full participant in the room at this point, so
+ # may not have an up-to-date list of the other homeservers participating in
+ # the room, so we send it on their behalf.
+ event.internal_metadata.send_on_behalf_of = origin
+
+ context = await self.state_handler.compute_event_context(event)
+ context = await self._check_event_auth(origin, event, context)
+ if context.rejected:
+ raise SynapseError(
+ 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
+ )
+
+ # for joins, we need to check the restrictions of restricted rooms
+ if event.membership == Membership.JOIN:
+ await self.check_join_restrictions(context, event)
+
+ # for knock events, we run the third-party event rules. It's not entirely clear
+ # why we don't do this for other sorts of membership events.
+ if event.membership == Membership.KNOCK:
+ event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Sending of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ # all looks good, we can persist the event.
+ await self._run_push_actions_and_persist_event(event, context)
+ return event, context
+
+ async def check_join_restrictions(
+ self, context: EventContext, event: EventBase
+ ) -> None:
+ """Check that restrictions in restricted join rules are matched
+
+ Called when we receive a join event via send_join.
+
+ Raises an auth error if the restrictions are not matched.
+ """
+ prev_state_ids = await context.get_prev_state_ids()
+
+ # Check if the user is already in the room or invited to the room.
+ user_id = event.state_key
+ prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+
+ # Check if the member should be allowed access via membership in a space.
+ await self._event_auth_handler.check_restricted_join_rules(
+ prev_state_ids,
+ event.room_version,
+ user_id,
+ prev_member_event,
+ )
+
+ @log_function
+ async def backfill(
+ self, dest: str, room_id: str, limit: int, extremities: List[str]
+ ) -> None:
+ """Trigger a backfill request to `dest` for the given `room_id`
+
+ This will attempt to get more events from the remote. If the other side
+ has no new events to offer, this returns without doing anything.
+
+ As the events are received, we check their signatures, and also do some
+ sanity-checking on them. If any of the backfilled events are invalid,
+ this method throws a SynapseError.
+
+ We might also raise an InvalidResponseError if the response from the remote
+ server is just bogus.
+
+ TODO: make this more useful to distinguish failures of the remote
+ server from invalid events (there is probably no point in trying to
+ re-fetch invalid events from every other HS in the room.)
+ """
+ if dest == self.server_name:
+ raise SynapseError(400, "Can't backfill from self.")
+
+ events = await self.federation_client.backfill(
+ dest, room_id, limit=limit, extremities=extremities
+ )
+
+ if not events:
+ return
+
+ # if there are any events in the wrong room, the remote server is buggy and
+ # should not be trusted.
+ for ev in events:
+ if ev.room_id != room_id:
+ raise InvalidResponseError(
+ f"Remote server {dest} returned event {ev.event_id} which is in "
+ f"room {ev.room_id}, when we were backfilling in {room_id}"
+ )
+
+ await self._process_pulled_events(dest, events, backfilled=True)
+
+ async def _get_missing_events_for_pdu(
+ self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+ ) -> None:
+ """
+ Args:
+ origin: Origin of the pdu. We will ask this server for the missing events.
+ pdu: received pdu
+ prevs: Set of event ids which we are missing
+ min_depth: Minimum depth of events to return.
+ """
+
+ room_id = pdu.room_id
+ event_id = pdu.event_id
+
+ seen = await self.store.have_events_in_timeline(prevs)
+
+ if not prevs - seen:
+ return
+
+ latest_list = await self.store.get_latest_event_ids_in_room(room_id)
+
+ # We add the prev events that we have seen to the latest
+ # list to ensure the remote server doesn't give them to us
+ latest = set(latest_list)
+ latest |= seen
+
+ logger.info(
+ "Requesting missing events between %s and %s",
+ shortstr(latest),
+ event_id,
+ )
+
+ # XXX: we set timeout to 10s to help work around
+ # https://github.com/matrix-org/synapse/issues/1733.
+ # The reason is to avoid holding the linearizer lock
+ # whilst processing inbound /send transactions, causing
+ # FDs to stack up and block other inbound transactions
+ # which empirically can currently take up to 30 minutes.
+ #
+ # N.B. this explicitly disables retry attempts.
+ #
+ # N.B. this also increases our chances of falling back to
+ # fetching fresh state for the room if the missing event
+ # can't be found, which slightly reduces our security.
+ # it may also increase our DAG extremity count for the room,
+ # causing additional state resolution? See #1760.
+ # However, fetching state doesn't hold the linearizer lock
+ # apparently.
+ #
+ # see https://github.com/matrix-org/synapse/pull/1744
+ #
+ # ----
+ #
+ # Update richvdh 2018/09/18: There are a number of problems with timing this
+ # request out aggressively on the client side:
+ #
+ # - it plays badly with the server-side rate-limiter, which starts tarpitting you
+ # if you send too many requests at once, so you end up with the server carefully
+ # working through the backlog of your requests, which you have already timed
+ # out.
+ #
+ # - for this request in particular, we now (as of
+ # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the
+ # server can't produce a plausible-looking set of prev_events - so we become
+ # much more likely to reject the event.
+ #
+ # - contrary to what it says above, we do *not* fall back to fetching fresh state
+ # for the room if get_missing_events times out. Rather, we give up processing
+ # the PDU whose prevs we are missing, which then makes it much more likely that
+ # we'll end up back here for the *next* PDU in the list, which exacerbates the
+ # problem.
+ #
+ # - the aggressive 10s timeout was introduced to deal with incoming federation
+ # requests taking 8 hours to process. It's not entirely clear why that was going
+ # on; certainly there were other issues causing traffic storms which are now
+ # resolved, and I think in any case we may be more sensible about our locking
+ # now. We're *certainly* more sensible about our logging.
+ #
+ # All that said: Let's try increasing the timeout to 60s and see what happens.
+
+ try:
+ missing_events = await self.federation_client.get_missing_events(
+ origin,
+ room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ timeout=60000,
+ )
+ except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
+ # We failed to get the missing events, but since we need to handle
+ # the case of `get_missing_events` not returning the necessary
+ # events anyway, it is safe to simply log the error and continue.
+ logger.warning("Failed to get prev_events: %s", e)
+ return
+
+ logger.info("Got %d prev_events", len(missing_events))
+ await self._process_pulled_events(origin, missing_events, backfilled=False)
+
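+ # A worked illustration of the `latest |= seen` step above, with hypothetical
+ # event ids: if the room's forward extremities are {"$a"} and the new pdu
+ # claims prev_events {"$b", "$c"} of which we already have "$b", then we ask
+ # the remote for events between {"$a", "$b"} and the pdu, so it does not
+ # resend "$b":
+ #
+ #     latest = {"$a"} | {"$b"}   # extremities, plus the prevs we already have
+ #     # -> get_missing_events(earliest_events_ids=["$a", "$b"],
+ #     #                       latest_events=[pdu], ...)
+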
+ async def _process_pulled_events(
+ self, origin: str, events: Iterable[EventBase], backfilled: bool
+ ) -> None:
+ """Process a batch of events we have pulled from a remote server
+
+ Pulls in any events required to auth the events, persists the received events,
+ and notifies clients, if appropriate.
+
+ Assumes the events have already had their signatures and hashes checked.
+
+ Args:
+ origin: The server we received these events from
+ events: The received events.
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
+ """
+
+ # We want to sort these by depth so we process them and
+ # tell clients about them in order.
+ sorted_events = sorted(events, key=lambda x: x.depth)
+
+ for ev in sorted_events:
+ with nested_logging_context(ev.event_id):
+ await self._process_pulled_event(origin, ev, backfilled=backfilled)
+
+ async def _process_pulled_event(
+ self, origin: str, event: EventBase, backfilled: bool
+ ) -> None:
+ """Process a single event that we have pulled from a remote server
+
+ Pulls in any events required to auth the event, persists the received event,
+ and notifies clients, if appropriate.
+
+ Assumes the event has already had its signatures and hashes checked.
+
+ This is broadly equivalent to on_receive_pdu, but applies somewhat different
+ logic in the case that we are missing prev_events (in particular, it just
+ requests the state at that point, rather than triggering a get_missing_events) -
+ so is appropriate when we have pulled the event from a remote server, rather
+ than having it pushed to us.
+
+ Args:
+ origin: The server we received this event from
+ event: The received event
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
+ """
+ logger.info("Processing pulled event %s", event)
+
+ # these should not be outliers.
+ assert not event.internal_metadata.is_outlier()
+
+ event_id = event.event_id
+
+ existing = await self.store.get_event(
+ event_id, allow_none=True, allow_rejected=True
+ )
+ if existing:
+ if not existing.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received event %s which we have already seen",
+ event_id,
+ )
+ return
+ logger.info("De-outliering event %s", event_id)
+
+ try:
+ self._sanity_check_event(event)
+ except SynapseError as err:
+ logger.warning("Event %s failed sanity check: %s", event_id, err)
+ return
+
+ try:
+ state = await self._resolve_state_at_missing_prevs(origin, event)
+ await self._process_received_pdu(
+ origin, event, state=state, backfilled=backfilled
+ )
+ except FederationError as e:
+ if e.code == 403:
+ logger.warning("Pulled event %s failed history check.", event_id)
+ else:
+ raise
+
+ async def _resolve_state_at_missing_prevs(
+ self, dest: str, event: EventBase
+ ) -> Optional[Iterable[EventBase]]:
+ """Calculate the state at an event with missing prev_events.
+
+ This is used when we have pulled a batch of events from a remote server, and
+ still don't have all the prev_events.
+
+ If we already have all the prev_events for `event`, this method does nothing.
+
+ Otherwise, the missing prevs become new backwards extremities, and we fall back
+ to asking the remote server for the state after each missing `prev_event`,
+ and resolving across them.
+
+ That's ok provided we then resolve the state against other bits of the DAG
+ before using it - in other words, that the received event `event` is not going
+ to become the only forwards_extremity in the room (which will ensure that you
+ can't just take over a room by sending an event, withholding its prev_events,
+ and declaring yourself to be an admin in the subsequent state request).
+
+ In other words: we should only call this method if `event` has been *pulled*
+ as part of a batch of missing prev events, or similar.
+
+ Args:
+ dest: the remote server to ask for state at the missing prevs. Typically,
+ this will be the server we got `event` from.
+ event: an event to check for missing prevs.
+
+ Returns:
+ if we already had all the prev events, `None`. Otherwise, returns a list of
+ the events in the state at `event`.
+ """
+ room_id = event.room_id
+ event_id = event.event_id
+
+ prevs = set(event.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ return None
+
+ logger.info(
+ "Event %s is missing prev_events %s: calculating state for a "
+ "backwards extremity",
+ event_id,
+ shortstr(missing_prevs),
+ )
+ # Calculate the state after each of the previous events, and
+ # resolve them to find the correct state at the current event.
+ event_map = {event_id: event}
+ try:
+ # Get the state of the events we know about
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
+
+ # state_maps is a list of mappings from (type, state_key) to event_id
+ state_maps: List[StateMap[str]] = list(ours.values())
+
+ # we don't need this any more, let's delete it.
+ del ours
+
+ # Ask the remote server for the states we don't
+ # know about
+ for p in missing_prevs:
+ logger.info("Requesting state after missing prev_event %s", p)
+
+ with nested_logging_context(p):
+ # note that if any of the missing prevs share missing state or
+ # auth events, the requests to fetch those events are deduped
+ # by the get_pdu_cache in federation_client.
+ remote_state = await self._get_state_after_missing_prev_event(
+ dest, room_id, p
+ )
+
+ remote_state_map = {
+ (x.type, x.state_key): x.event_id for x in remote_state
+ }
+ state_maps.append(remote_state_map)
+
+ for x in remote_state:
+ event_map[x.event_id] = x
+
+ room_version = await self.store.get_room_version_id(room_id)
+ state_map = await self._state_resolution_handler.resolve_events_with_store(
+ room_id,
+ room_version,
+ state_maps,
+ event_map,
+ state_res_store=StateResolutionStore(self.store),
+ )
+
+ # We need to give _process_received_pdu the actual state events
+ # rather than event ids, so generate that now.
+
+ # First though we need to fetch all the events that are in
+ # state_map, so we can build up the state below.
+ evs = await self.store.get_events(
+ list(state_map.values()),
+ get_prev_content=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ event_map.update(evs)
+
+ state = [event_map[e] for e in state_map.values()]
+ except Exception:
+ logger.warning(
+ "Error attempting to resolve state at missing prev_events",
+ exc_info=True,
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ "We can't get valid state history.",
+ affected=event_id,
+ )
+ return state
+
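+ # For reference, each StateMap[str] fed into resolve_events_with_store above
+ # maps (event type, state key) to an event id; a hypothetical entry set:
+ #
+ #     {
+ #         ("m.room.create", ""): "$create",
+ #         ("m.room.power_levels", ""): "$power_levels",
+ #         ("m.room.member", "@bob:example.com"): "$bob_join",
+ #     }
+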
+ async def _get_state_after_missing_prev_event(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ ) -> List[EventBase]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+
+ Returns:
+ A list of events in the state, including the event itself
+ """
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ logger.debug(
+ "state_ids returned %i state events, %i auth events",
+ len(state_event_ids),
+ len(auth_event_ids),
+ )
+
+ # start by just trying to fetch the events from the store
+ desired_events = set(state_event_ids)
+ desired_events.add(event_id)
+ logger.debug("Fetching %i events from cache/store", len(desired_events))
+ fetched_events = await self.store.get_events(
+ desired_events, allow_rejected=True
+ )
+
+ missing_desired_events = desired_events - fetched_events.keys()
+ logger.debug(
+ "We are missing %i events (got %i)",
+ len(missing_desired_events),
+ len(fetched_events),
+ )
+
+ # We probably won't need most of the auth events, so let's just check which
+ # we have for now, rather than thrashing the event cache with them all
+ # unnecessarily.
+
+ # TODO: we probably won't actually need all of the auth events, since we
+ # already have a bunch of the state events. It would be nice if the
+ # federation api gave us a way of finding out which we actually need.
+
+ missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events.difference_update(
+ await self.store.have_seen_events(room_id, missing_auth_events)
+ )
+ logger.debug("We are also missing %i auth events", len(missing_auth_events))
+
+ missing_events = missing_desired_events | missing_auth_events
+ logger.debug("Fetching %i events from remote", len(missing_events))
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ await self.store.get_events(missing_desired_events, allow_rejected=True)
+ )
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = [
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ del fetched_events[bad_event_id]
+
+ # if we couldn't get the prev event in question, that's a problem.
+ remote_event = fetched_events.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+
+ # missing state at that event is a warning, not a blocker
+ # XXX: this doesn't sound right? it means that we'll end up with incomplete
+ # state.
+ failed_to_fetch = desired_events - fetched_events.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
+ ]
+
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
+
+ return remote_state
+
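+ # The bookkeeping above is plain set arithmetic; a minimal sketch with
+ # hypothetical ids:
+ #
+ #     desired = {"$s1", "$s2", "$prev"}   # state_event_ids plus the event itself
+ #     fetched = {"$s1"}                   # what the cache/store already had
+ #     missing = desired - fetched         # {"$s2", "$prev"}: fetch from remote
+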
+ async def _process_received_pdu(
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ backfilled: bool = False,
+ ) -> None:
+ """Called when we have a new pdu. We need to do auth checks and put it
+ through the StateHandler.
+
+ Args:
+ origin: server sending the event
+
+ event: event to be persisted
+
+ state: Normally None, but if we are handling a gap in the graph
+ (ie, we are missing one or more prev_events), the resolved state at the
+ event
+
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
+ """
+ logger.debug("Processing event: %s", event)
+
+ try:
+ context = await self.state_handler.compute_event_context(
+ event, old_state=state
+ )
+ await self._auth_and_persist_event(
+ origin, event, context, state=state, backfilled=backfilled
+ )
+ except AuthError as e:
+ raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
+
+ if backfilled:
+ return
+
+ # For encrypted messages we check that we know about the sending device,
+ # if we don't then we mark the device cache for that user as stale.
+ if event.type == EventTypes.Encrypted:
+ device_id = event.content.get("device_id")
+ sender_key = event.content.get("sender_key")
+
+ cached_devices = await self.store.get_cached_devices_for_user(event.sender)
+
+ resync = False # Whether we should resync device lists.
+
+ device = None
+ if device_id is not None:
+ device = cached_devices.get(device_id)
+ if device is None:
+ logger.info(
+ "Received event from remote device not in our cache: %s %s",
+ event.sender,
+ device_id,
+ )
+ resync = True
+
+ # We also check if the `sender_key` matches what we expect.
+ if sender_key is not None:
+ # Figure out what sender key we're expecting. If we know the
+ # device and recognize the algorithm then we can work out the
+ # exact key to expect. Otherwise check it matches any key we
+ # have for that device.
+
+ current_keys: Container[str] = []
+
+ if device:
+ keys = device.get("keys", {}).get("keys", {})
+
+ if (
+ event.content.get("algorithm")
+ == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2
+ ):
+ # For this algorithm we expect a curve25519 key.
+ key_name = "curve25519:%s" % (device_id,)
+ current_keys = [keys.get(key_name)]
+ else:
+ # We don't understand the algorithm, so we just
+ # check it matches a key for the device.
+ current_keys = keys.values()
+ elif device_id:
+ # We don't have any keys for the device ID.
+ pass
+ else:
+ # The event didn't include a device ID, so we just look for
+ # keys across all devices.
+ current_keys = [
+ key
+ for device in cached_devices.values()
+ for key in device.get("keys", {}).get("keys", {}).values()
+ ]
+
+ # We now check that the sender key matches (one of) the expected
+ # keys.
+ if sender_key not in current_keys:
+ logger.info(
+ "Received event from remote device with unexpected sender key: %s %s: %s",
+ event.sender,
+ device_id or "<no device_id>",
+ sender_key,
+ )
+ resync = True
+
+ if resync:
+ run_as_background_process(
+ "resync_device_due_to_pdu",
+ self._resync_device,
+ event.sender,
+ )
+
+ await self._handle_marker_event(origin, event)
+
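+ # Sketch of the expected-key check above for the Megolm case, assuming a
+ # cached device record shaped like the ones returned by
+ # get_cached_devices_for_user (the device id and key material are made up):
+ #
+ #     device = {"keys": {"keys": {"curve25519:DEVID": "<base64 key>"}}}
+ #     key_name = "curve25519:%s" % ("DEVID",)
+ #     expected = device["keys"]["keys"].get(key_name)
+ #     resync = sender_key != expected
+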
+ async def _resync_device(self, sender: str) -> None:
+ """We have detected that the device list for the given user may be out
+ of sync, so we try and resync them.
+ """
+
+ try:
+ await self.store.mark_remote_user_device_cache_as_stale(sender)
+
+ # Immediately attempt a resync in the background
+ if self.config.worker_app:
+ await self._user_device_resync(user_id=sender)
+ else:
+ await self._device_list_updater.user_device_resync(sender)
+ except Exception:
+ logger.exception("Failed to resync device for %s", sender)
+
+ async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
+ """Handles backfilling the insertion event when we receive a marker
+ event that points to one.
+
+ Args:
+ origin: Origin of the event. We will ask this server for the insertion event
+ marker_event: The event to process
+ """
+
+ if marker_event.type != EventTypes.MSC2716_MARKER:
+ # Not a marker event
+ return
+
+ if marker_event.rejected_reason is not None:
+ # Rejected event
+ return
+
+ # Skip processing a marker event if the room version doesn't
+ # support it.
+ room_version = await self.store.get_room_version(marker_event.room_id)
+ if not room_version.msc2716_historical:
+ return
+
+ logger.debug("_handle_marker_event: received %s", marker_event)
+
+ insertion_event_id = marker_event.content.get(
+ EventContentFields.MSC2716_MARKER_INSERTION
+ )
+
+ if insertion_event_id is None:
+ # Nothing to retrieve then (invalid marker)
+ return
+
+ logger.debug(
+ "_handle_marker_event: backfilling insertion event %s", insertion_event_id
+ )
+
+ await self._get_events_and_persist(
+ origin,
+ marker_event.room_id,
+ [insertion_event_id],
+ )
+
+ insertion_event = await self.store.get_event(
+ insertion_event_id, allow_none=True
+ )
+ if insertion_event is None:
+ logger.warning(
+ "_handle_marker_event: server %s didn't return insertion event %s for marker %s",
+ origin,
+ insertion_event_id,
+ marker_event.event_id,
+ )
+ return
+
+ logger.debug(
+ "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s",
+ insertion_event,
+ marker_event,
+ )
+
+ await self.store.insert_insertion_extremity(
+ insertion_event_id, marker_event.room_id
+ )
+
+ logger.debug(
+ "_handle_marker_event: insertion extremity added for %s from marker event %s",
+ insertion_event,
+ marker_event,
+ )
+
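+ # For orientation, the marker event content consumed above is expected to
+ # carry the insertion event's id under EventContentFields.MSC2716_MARKER_INSERTION;
+ # a hypothetical example:
+ #
+ #     {
+ #         EventContentFields.MSC2716_MARKER_INSERTION: "$insertion_event_id",
+ #     }
+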
+ async def _get_events_and_persist(
+ self, destination: str, room_id: str, events: Iterable[str]
+ ) -> None:
+ """Fetch the given events from a server, and persist them as outliers.
+
+ This function *does not* recursively get missing auth events of the
+ newly fetched events. Callers must include in the `events` argument
+ any missing events from the auth chain.
+
+ Logs a warning if we can't find the given event.
+ """
+
+ room_version = await self.store.get_room_version(room_id)
+
+ event_map: Dict[str, EventBase] = {}
+
+ async def get_event(event_id: str):
+ with nested_logging_context(event_id):
+ try:
+ event = await self.federation_client.get_pdu(
+ [destination],
+ event_id,
+ room_version,
+ outlier=True,
+ )
+ if event is None:
+ logger.warning(
+ "Server %s didn't return event %s",
+ destination,
+ event_id,
+ )
+ return
+
+ event_map[event.event_id] = event
+
+ except Exception as e:
+ logger.warning(
+ "Error fetching missing state/auth event %s: %s %s",
+ event_id,
+ type(e),
+ e,
+ )
+
+ await concurrently_execute(get_event, events, 5)
+
+ # Make a map of auth events for each event. We do this after fetching
+ # all the events as some of the events' auth events will be in the list
+ # of requested events.
+
+ auth_events = [
+ aid
+ for event in event_map.values()
+ for aid in event.auth_event_ids()
+ if aid not in event_map
+ ]
+ persisted_events = await self.store.get_events(
+ auth_events,
+ allow_rejected=True,
+ )
+
+ event_infos = []
+ for event in event_map.values():
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+ else:
+ logger.info("Missing auth event %s", auth_event_id)
+
+ event_infos.append(_NewEventInfo(event, auth))
+
+ if event_infos:
+ await self._auth_and_persist_events(
+ destination,
+ room_id,
+ event_infos,
+ )
+
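+ # `concurrently_execute(get_event, events, 5)` above runs `get_event` over the
+ # ids with at most five fetches in flight at once. A rough stdlib-asyncio
+ # equivalent (illustrative only; Synapse actually schedules this on Twisted):
+ #
+ #     import asyncio
+ #     sem = asyncio.Semaphore(5)
+ #     async def bounded(eid):
+ #         async with sem:
+ #             await get_event(eid)
+ #     await asyncio.gather(*(bounded(e) for e in events))
+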
+ async def _auth_and_persist_events(
+ self,
+ origin: str,
+ room_id: str,
+ event_infos: Collection[_NewEventInfo],
+ ) -> None:
+ """Creates the appropriate contexts and persists events. The events
+ should not depend on one another, e.g. this should be used to persist
+ a bunch of outliers, but not a chunk of individual events that depend
+ on each other for state calculations.
+
+ Notifies about the events where appropriate.
+ """
+
+ if not event_infos:
+ return
+
+ async def prep(ev_info: _NewEventInfo):
+ event = ev_info.event
+ with nested_logging_context(suffix=event.event_id):
+ res = await self.state_handler.compute_event_context(event)
+ res = await self._check_event_auth(
+ origin,
+ event,
+ res,
+ claimed_auth_event_map=ev_info.claimed_auth_event_map,
+ )
+ return res
+
+ contexts = await make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(prep, ev_info) for ev_info in event_infos],
+ consumeErrors=True,
+ )
+ )
+
+ await self.persist_events_and_notify(
+ room_id,
+ [
+ (ev_info.event, context)
+ for ev_info, context in zip(event_infos, contexts)
+ ],
+ )
+
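+ # The gather pattern above is the usual Synapse idiom for parallelism under
+ # Twisted: run_in_background gives each `prep` call its own log context, and
+ # make_deferred_yieldable restores the caller's log context once the combined
+ # Deferred from gatherResults has fired.
+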
+ async def _auth_and_persist_event(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ state: Optional[Iterable[EventBase]] = None,
+ claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
+ backfilled: bool = False,
+ ) -> None:
+ """
+ Process an event by performing auth checks and then persisting to the database.
+
+ Args:
+ origin: The host the event originates from.
+ event: The event itself.
+ context:
+ The event context.
+
+ state:
+ The state events used to check the event for soft-fail. If this is
+ not provided the current state events will be used.
+
+ claimed_auth_event_map:
+ A map of (type, state_key) => event for the event's claimed auth_events.
+ Possibly incomplete, and possibly including events that are not yet
+ persisted, or authed, or in the right room.
+
+ Only populated where we may not already have persisted these events -
+ for example, when populating outliers.
+
+ backfilled: True if the event was backfilled.
+ """
+ context = await self._check_event_auth(
+ origin,
+ event,
+ context,
+ state=state,
+ claimed_auth_event_map=claimed_auth_event_map,
+ backfilled=backfilled,
+ )
+
+ await self._run_push_actions_and_persist_event(event, context, backfilled)
+
+ async def _check_event_auth(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ state: Optional[Iterable[EventBase]] = None,
+ claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
+ backfilled: bool = False,
+ ) -> EventContext:
+ """
+ Checks whether an event should be rejected (for failing auth checks).
+
+ Args:
+ origin: The host the event originates from.
+ event: The event itself.
+ context:
+ The event context.
+
+ state:
+ The state events used to check the event for soft-fail. If this is
+ not provided the current state events will be used.
+
+ claimed_auth_event_map:
+ A map of (type, state_key) => event for the event's claimed auth_events.
+ Possibly incomplete, and possibly including events that are not yet
+ persisted, or authed, or in the right room.
+
+ Only populated where we may not already have persisted these events -
+ for example, when populating outliers, or the state for a backwards
+ extremity.
+
+ backfilled: True if the event was backfilled.
+
+ Returns:
+ The updated context object.
+ """
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ if claimed_auth_event_map:
+ # if we have a copy of the auth events from the event, use that as the
+ # basis for auth.
+ auth_events = claimed_auth_event_map
+ else:
+ # otherwise, we calculate what the auth events *should* be, and use that
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
+ event, prev_state_ids, for_verification=True
+ )
+ auth_events_x = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
+
+ try:
+ (
+ context,
+ auth_events_for_auth,
+ ) = await self._update_auth_events_and_context_for_auth(
+ origin, event, context, auth_events
+ )
+ except Exception:
+ # We don't really mind if the above fails, so let's not fail
+ # processing if it does. However, it really shouldn't fail so
+ # let's still log as an exception since we'll still want to fix
+ # any bugs.
+ logger.exception(
+ "Failed to double check auth events for %s with remote. "
+ "Ignoring failure and continuing processing of event.",
+ event.event_id,
+ )
+ auth_events_for_auth = auth_events
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
+ except AuthError as e:
+ logger.warning("Failed auth resolution for %r because %s", event, e)
+ context.rejected = RejectedReason.AUTH_ERROR
+
+ if not context.rejected:
+ await self._check_for_soft_fail(event, state, backfilled, origin=origin)
+
+ if event.type == EventTypes.GuestAccess and not context.rejected:
+ await self.maybe_kick_guest_users(event)
+
+ # If we are going to send this event over federation we precalculate
+ # the joined hosts.
+ if event.internal_metadata.get_send_on_behalf_of():
+ await self.event_creation_handler.cache_joined_hosts_for_event(
+ event, context
+ )
+
+ return context
+
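+ # The two auth-event sources above differ in provenance: the claimed map comes
+ # from the (untrusted) remote event, while the calculated set is derived from
+ # our own view of the room state. A hypothetical calculated map for a message
+ # event:
+ #
+ #     {
+ #         ("m.room.create", ""): create_ev,
+ #         ("m.room.power_levels", ""): power_levels_ev,
+ #         ("m.room.member", "@sender:example.com"): sender_join_ev,
+ #     }
+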
+ async def _check_for_soft_fail(
+ self,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ backfilled: bool,
+ origin: str,
+ ) -> None:
+ """Checks if we should soft fail the event; if so, marks the event as
+ such.
+
+ Args:
+ event
+ state: The state at the event if we don't have all the event's prev events
+ backfilled: Whether the event is from backfill
+ origin: The host the event originates from.
+ """
+ # For new (non-backfilled and non-outlier) events we check if the event
+ # passes auth based on the current state. If it doesn't then we
+ # "soft-fail" the event.
+ if backfilled or event.internal_metadata.is_outlier():
+ return
+
+ extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids_list)
+ prev_event_ids = set(event.prev_event_ids())
+
+ if extrem_ids == prev_event_ids:
+ # If they're the same then the current state is the same as the
+ # state at the event, so no point rechecking auth for soft fail.
+ return
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ # Calculate the "current state".
+ if state is not None:
+ # If we're explicitly given the state then we won't have all the
+ # prev events, and so we have a gap in the graph. In this case
+ # we want to be a little careful as we might have been down for
+ # a while and have an incorrect view of the current state,
+ # however we still want to do checks as gaps are easy to
+ # maliciously manufacture.
+ #
+ # So we use a "current state" that is actually a state
+ # resolution across the current forward extremities and the
+ # given state at the event. This should correctly handle cases
+ # like bans, especially with state res v2.
+
+ state_sets_d = await self.state_store.get_state_groups(
+ event.room_id, extrem_ids
+ )
+ state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
+ state_sets.append(state)
+ current_states = await self.state_handler.resolve_events(
+ room_version, state_sets, event
+ )
+ current_state_ids: StateMap[str] = {
+ k: e.event_id for k, e in current_states.items()
+ }
+ else:
+ current_state_ids = await self.state_handler.get_current_state_ids(
+ event.room_id, latest_event_ids=extrem_ids
+ )
+
+ logger.debug(
+ "Doing soft-fail check for %s: state %s",
+ event.event_id,
+ current_state_ids,
+ )
+
+ # Now check if the event passes auth against said current state
+ auth_types = auth_types_for_event(room_version_obj, event)
+ current_state_ids_list = [
+ e for k, e in current_state_ids.items() if k in auth_types
+ ]
+
+ auth_events_map = await self.store.get_events(current_state_ids_list)
+ current_auth_events = {
+ (e.type, e.state_key): e for e in auth_events_map.values()
+ }
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+ except AuthError as e:
+ logger.warning(
+ "Soft-failing %r (from %s) because %s",
+ event,
+ origin,
+ e,
+ extra={
+ "room_id": event.room_id,
+ "mxid": event.sender,
+ "hs": origin,
+ },
+ )
+ soft_failed_event_counter.inc()
+ event.internal_metadata.soft_failed = True
+
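+ # Example of the gap test above, with hypothetical ids: if the forward
+ # extremities are {"$x"} and the event's prev_events are also {"$x"}, the
+ # current state *is* the state at the event and the recheck is skipped; if
+ # prev_events were {"$y"} instead, there is a gap and auth is re-checked
+ # against the resolved "current state".
+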
+ async def _update_auth_events_and_context_for_auth(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ input_auth_events: StateMap[EventBase],
+ ) -> Tuple[EventContext, StateMap[EventBase]]:
+ """Helper for _check_event_auth. See there for docs.
+
+ Checks whether a given event has the expected auth events. If it
+ doesn't then we talk to the remote server to compare state to see if
+ we can come to a consensus (e.g. if one server missed some valid
+ state).
+
+ This attempts to resolve any potential divergence of state between
+ servers, but is not essential and so failures should not block further
+ processing of the event.
+
+ Args:
+ origin:
+ event:
+ context:
+
+ input_auth_events:
+ Map from (event_type, state_key) to event
+
+ Normally this is our calculated auth_events, based on the state of the room
+ at the event's position in the DAG, though occasionally (eg if the
+ event is an outlier) it may be the auth events claimed by the remote
+ server.
+
+ Returns:
+ updated context, updated auth event map
+ """
+ # take a copy of input_auth_events before we modify it.
+ auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
+
+ event_auth_events = set(event.auth_event_ids())
+
+ # missing_auth is the set of the event's auth_events which we don't yet have
+ # in auth_events.
+ missing_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
+
+ # if we have missing events, we need to fetch those events from somewhere.
+ #
+ # we start by checking if they are in the store, and then try calling /event_auth/.
+ if missing_auth:
+ have_events = await self.store.have_seen_events(event.room_id, missing_auth)
+ logger.debug("Events %s are in the store", have_events)
+ missing_auth.difference_update(have_events)
+
+ if missing_auth:
+ # If we don't have all the auth events, we need to get them.
+ logger.info("auth_events contains unknown events: %s", missing_auth)
+ try:
+ try:
+ remote_auth_chain = await self.federation_client.get_event_auth(
+ origin, event.room_id, event.event_id
+ )
+ except RequestSendFailed as e1:
+ # The other side isn't around or doesn't implement the
+ # endpoint, so let's just bail out.
+ logger.info("Failed to get event auth from remote: %s", e1)
+ return context, auth_events
+
+ seen_remotes = await self.store.have_seen_events(
+ event.room_id, [e.event_id for e in remote_auth_chain]
+ )
+
+ for e in remote_auth_chain:
+ if e.event_id in seen_remotes:
+ continue
+
+ if e.event_id == event.event_id:
+ continue
+
+ try:
+ auth_ids = e.auth_event_ids()
+ auth = {
+ (e.type, e.state_key): e
+ for e in remote_auth_chain
+ if e.event_id in auth_ids or e.type == EventTypes.Create
+ }
+ e.internal_metadata.outlier = True
+
+ logger.debug(
+ "_check_event_auth %s missing_auth: %s",
+ event.event_id,
+ e.event_id,
+ )
+ missing_auth_event_context = (
+ await self.state_handler.compute_event_context(e)
+ )
+ await self._auth_and_persist_event(
+ origin,
+ e,
+ missing_auth_event_context,
+ claimed_auth_event_map=auth,
+ )
+
+ if e.event_id in event_auth_events:
+ auth_events[(e.type, e.state_key)] = e
+ except AuthError:
+ pass
+
+ except Exception:
+ logger.exception("Failed to get auth chain")
+
+ if event.internal_metadata.is_outlier():
+ # XXX: given that, for an outlier, we'll be working with the
+ # event's *claimed* auth events rather than those we calculated:
+ # (a) is there any point in this test, since different_auth below will
+ # obviously be empty
+ # (b) alternatively, why don't we do it earlier?
+ logger.info("Skipping auth_event fetch for outlier")
+ return context, auth_events
+
+ different_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
+
+ if not different_auth:
+ return context, auth_events
+
+ logger.info(
+ "auth_events refers to events which are not in our calculated auth "
+ "chain: %s",
+ different_auth,
+ )
+
+ # XXX: currently this checks for redactions, but I'm not convinced that
+ # is necessary.
+ different_events = await self.store.get_events_as_list(different_auth)
+
+ for d in different_events:
+ if d.room_id != event.room_id:
+ logger.warning(
+ "Event %s refers to auth_event %s which is in a different room",
+ event.event_id,
+ d.event_id,
+ )
+
+ # don't attempt to resolve the claimed auth events against our own
+ # in this case: just use our own auth events.
+ #
+ # XXX: should we reject the event in this case? It feels like we should,
+ # but then shouldn't we also do so if we've failed to fetch any of the
+ # auth events?
+ return context, auth_events
+
+ # now we state-resolve between our own idea of the auth events, and the remote's
+ # idea of them.
+
+ local_state = auth_events.values()
+ remote_auth_events = dict(auth_events)
+ remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
+ remote_state = remote_auth_events.values()
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ new_state = await self.state_handler.resolve_events(
+ room_version, (local_state, remote_state), event
+ )
+
+ logger.info(
+ "After state res: updating auth_events with new state %s",
+ {
+ (d.type, d.state_key): d.event_id
+ for d in new_state.values()
+ if auth_events.get((d.type, d.state_key)) != d
+ },
+ )
+
+ auth_events.update(new_state)
+
+ context = await self._update_context_for_auth_events(
+ event, context, auth_events
+ )
+
+ return context, auth_events
+
+ async def _update_context_for_auth_events(
+ self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
+ ) -> EventContext:
+ """Update the state_ids in an event context after auth event resolution,
+ storing the changes as a new state group.
+
+ Args:
+ event: The event we're handling the context for
+
+ context: initial event context
+
+ auth_events: Events to update in the event context.
+
+ Returns:
+ new event context
+ """
+ # exclude the state key of the new event from the current_state in the context.
+ if event.is_state():
+ event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
+ else:
+ event_key = None
+ state_updates = {
+ k: a.event_id for k, a in auth_events.items() if k != event_key
+ }
+
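+ # take mutable copies of the (immutable) state maps before updating them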
+ current_state_ids = await context.get_current_state_ids()
+ current_state_ids = dict(current_state_ids) # type: ignore
+
+ current_state_ids.update(state_updates)
+
+ prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = dict(prev_state_ids)
+
+ prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
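+ # note that, unlike current_state, we don't exclude the event's own
+ # (type, state_key) here: the state *before* the event can legitimately
+ # include an entry for it (eg the user's previous membership)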
+
+ # create a new state group as a delta from the existing one.
+ prev_group = context.state_group
+ state_group = await self.state_store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ current_state_ids=current_state_ids,
+ )
+
+ return EventContext.with_state(
+ state_group=state_group,
+ state_group_before_event=context.state_group_before_event,
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ )
+
+ async def _run_push_actions_and_persist_event(
+ self, event: EventBase, context: EventContext, backfilled: bool = False
+ ):
+ """Run the push actions for a received event, and persist it.
+
+ Args:
+ event: The event itself.
+ context: The event context.
+ backfilled: True if the event was backfilled.
+ """
+ try:
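+ # we don't generate push actions for outliers, backfilled or rejected
+ # events, nor for events whose depth is below the room's min_depth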
+ if (
+ not event.internal_metadata.is_outlier()
+ and not backfilled
+ and not context.rejected
+ and (await self.store.get_min_depth(event.room_id)) <= event.depth
+ ):
+ await self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
+
+ await self.persist_events_and_notify(
+ event.room_id, [(event, context)], backfilled=backfilled
+ )
+ except Exception:
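+ # persisting the event failed, so remove any push actions we staged
+ # for it before re-raising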
+ run_in_background(
+ self.store.remove_push_actions_from_staging, event.event_id
+ )
+ raise
+
+ async def persist_events_and_notify(
+ self,
+ room_id: str,
+ event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
+ backfilled: bool = False,
+ ) -> int:
+ """Persists events and tells the notifier/pushers about them, if
+ necessary.
+
+ Args:
+ room_id: The room ID of events being persisted.
+ event_and_contexts: Sequence of events with their associated
+ context that should be persisted. All events must belong to
+ the same room.
+ backfilled: Whether these events are a result of
+ backfilling or not
+
+ Returns:
+ The stream ID after which all events have been persisted.
+ """
+ if not event_and_contexts:
+ return self.store.get_current_events_token()
+
+ instance = self.config.worker.events_shard_config.get_instance(room_id)
+ if instance != self._instance_name:
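+ # this room's events are persisted on a different worker, so send
+ # them over replication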
+ # Limit the number of events sent over replication. We choose 200
+ # here as that is what we default to in `max_request_body_size(..)`
+ for batch in batch_iter(event_and_contexts, 200):
+ result = await self._send_events(
+ instance_name=instance,
+ store=self.store,
+ room_id=room_id,
+ event_and_contexts=batch,
+ backfilled=backfilled,
+ )
+ return result["max_stream_id"]
+ else:
+ assert self.storage.persistence
+
+ # Note that this returns the events that were persisted, which may not be
+ # the same as were passed in if some were deduplicated due to transaction IDs.
+ events, max_stream_token = await self.storage.persistence.persist_events(
+ event_and_contexts, backfilled=backfilled
+ )
+
+ if self._ephemeral_messages_enabled:
+ for event in events:
+ # If there's an expiry timestamp on the event, schedule its expiry.
+ self._message_handler.maybe_schedule_expiry(event)
+
+ if not backfilled: # Never notify for backfilled events
+ for event in events:
+ await self._notify_persisted_event(event, max_stream_token)
+
+ return max_stream_token.stream
+
+ async def _notify_persisted_event(
+ self, event: EventBase, max_stream_token: RoomStreamToken
+ ) -> None:
+ """Checks to see if notifier/pushers should be notified about the
+ event or not.
+
+ Args:
+ event:
+ max_stream_token: The max_stream_id returned by persist_events
+ """
+
+ extra_users = []
+ if event.type == EventTypes.Member:
+ target_user_id = event.state_key
+
+ # We notify for membership events if it's an invite for one of our
+ # users
+ if event.internal_metadata.is_outlier():
+ if event.membership != Membership.INVITE:
+ if not self.is_mine_id(target_user_id):
+ return
+
+ target_user = UserID.from_string(target_user_id)
+ extra_users.append(target_user)
+ elif event.internal_metadata.is_outlier():
+ return
+
+ # the event has been persisted so it should have a stream ordering.
+ assert event.internal_metadata.stream_ordering
+
+ event_pos = PersistedEventPosition(
+ self._instance_name, event.internal_metadata.stream_ordering
+ )
+ self.notifier.on_new_room_event(
+ event, event_pos, max_stream_token, extra_users=extra_users
+ )
+
+ def _sanity_check_event(self, ev: EventBase) -> None:
+ """
+ Do some early sanity checks of a received event
+
+ In particular, checks it doesn't have an excessive number of
+ prev_events or auth_events, which could cause a huge state resolution
+ or cascade of event fetches.
+
+ Args:
+ ev: event to be checked
+
+ Raises:
+ SynapseError if the event does not pass muster
+ """
+ if len(ev.prev_event_ids()) > 20:
+ logger.warning(
+ "Rejecting event %s which has %i prev_events",
+ ev.event_id,
+ len(ev.prev_event_ids()),
+ )
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events")
+
+ if len(ev.auth_event_ids()) > 10:
+ logger.warning(
+ "Rejecting event %s which has %i auth_events",
+ ev.event_id,
+ len(ev.auth_event_ids()),
+ )
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
+
+ async def get_min_depth_for_context(self, context: str) -> int:
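+ """Get the current min_depth of the given room ("context" is a legacy name for a room ID)."""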
+ return await self.store.get_min_depth(context)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 79cadb7b57..a0b3145f4e 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.clock = hs.get_clock()
- self.federation_handler = hs.get_federation_handler()
+ self.federation_event_handler = hs.get_federation_event_handler()
@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
@@ -127,7 +127,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
- max_stream_id = await self.federation_handler.persist_events_and_notify(
+ max_stream_id = await self.federation_event_handler.persist_events_and_notify(
room_id, event_and_contexts, backfilled
)
diff --git a/synapse/server.py b/synapse/server.py
index de6517663e..5adeeff61a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -76,6 +76,7 @@ from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
from synapse.handlers.event_auth import EventAuthHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.federation import FederationHandler
+from synapse.handlers.federation_event import FederationEventHandler
from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
from synapse.handlers.identity import IdentityHandler
from synapse.handlers.initial_sync import InitialSyncHandler
@@ -547,6 +548,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return FederationHandler(self)
@cache_in_self
+ def get_federation_event_handler(self) -> FederationEventHandler:
+ return FederationEventHandler(self)
+
+ @cache_in_self
def get_identity_handler(self) -> IdentityHandler:
return IdentityHandler(self)
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 383214ab50..663960ff53 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -208,7 +208,7 @@ class FederationKnockingTestCase(
async def _check_event_auth(origin, event, context, *args, **kwargs):
return context
- homeserver.get_federation_handler()._check_event_auth = _check_event_auth
+ homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth
return super().prepare(reactor, clock, homeserver)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index c72a8972a3..6c67a16de9 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -130,7 +130,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
with LoggingContext("send_rejected"):
- d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ d = run_in_background(
+ self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
+ )
self.get_success(d)
# that should have been rejected
@@ -182,7 +184,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
with LoggingContext("send_rejected"):
- d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ d = run_in_background(
+ self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
+ )
self.get_success(d)
# that should have been rejected
@@ -311,7 +315,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
d = run_in_background(
- self.handler.on_receive_pdu, OTHER_SERVER, message_event
+ self.hs.get_federation_event_handler().on_receive_pdu,
+ OTHER_SERVER,
+ message_event,
)
self.get_success(d)
@@ -382,7 +388,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
join_event.signatures[other_server] = {"x": "y"}
with LoggingContext("send_join"):
d = run_in_background(
- self.handler.on_send_membership_event, other_server, join_event
+ self.hs.get_federation_event_handler().on_send_membership_event,
+ other_server,
+ join_event,
)
self.get_success(d)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 0a52bc8b72..671dc7d083 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -885,7 +885,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory()
- self.federation_handler = hs.get_federation_handler()
+ self.federation_event_handler = hs.get_federation_event_handler()
self.presence_handler = hs.get_presence_handler()
# self.event_builder_for_2 = EventBuilderFactory(hs)
@@ -1026,7 +1026,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
)
- self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
+ self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
# Check that it was successfully persisted.
self.get_success(self.store.get_event(event.event_id))
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index af5dfca752..92a5b53e11 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -205,7 +205,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
- federation = self.hs.get_federation_handler()
+ federation = self.hs.get_federation_event_handler()
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 348fcb72a7..61c9d7c2ef 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,7 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
self.handler = self.homeserver.get_federation_handler()
- self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
+ federation_event_handler = self.homeserver.get_federation_event_handler()
+ federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
context
)
self.client = self.homeserver.get_federation_client()
@@ -85,7 +86,9 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Send the join, it should return None (which is not an error)
self.assertEqual(
- self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
+ self.get_success(
+ federation_event_handler.on_receive_pdu("test.serv", join_event)
+ ),
None,
)
@@ -129,9 +132,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
+ federation_event_handler = self.homeserver.get_federation_event_handler()
with LoggingContext("test-context"):
failure = self.get_failure(
- self.handler.on_receive_pdu("test.serv", lying_event),
+ federation_event_handler.on_receive_pdu("test.serv", lying_event),
FederationError,
)