diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9f055f00cf..9ec90ac8c1 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -36,6 +36,7 @@ from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
+ GuestAccess,
Membership,
RejectedReason,
RoomEncryptionAlgorithms,
@@ -53,7 +54,6 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers._base import BaseHandler
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
@@ -116,7 +116,7 @@ class _NewEventInfo:
claimed_auth_event_map: StateMap[EventBase]
-class FederationEventHandler(BaseHandler):
+class FederationEventHandler:
"""Handles events that originated from federation.
Responsible for handing incoming events and passing them on to the rest
@@ -124,26 +124,28 @@ class FederationEventHandler(BaseHandler):
"""
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self._store = hs.get_datastore()
+ self._storage = hs.get_storage()
+ self._state_store = self._storage.state
- self.store = hs.get_datastore()
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
-
- self.state_handler = hs.get_state_handler()
- self.event_creation_handler = hs.get_event_creation_handler()
+ self._state_handler = hs.get_state_handler()
+ self._event_creation_handler = hs.get_event_creation_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self._message_handler = hs.get_message_handler()
- self.action_generator = hs.get_action_generator()
+ self._action_generator = hs.get_action_generator()
self._state_resolution_handler = hs.get_state_resolution_handler()
+ # avoid a circular dependency by deferring execution here
+ self._get_room_member_handler = hs.get_room_member_handler
- self.federation_client = hs.get_federation_client()
- self.third_party_event_rules = hs.get_third_party_event_rules()
+ self._federation_client = hs.get_federation_client()
+ self._third_party_event_rules = hs.get_third_party_event_rules()
+ self._notifier = hs.get_notifier()
- self.is_mine_id = hs.is_mine_id
+ self._is_mine_id = hs.is_mine_id
+ self._server_name = hs.hostname
self._instance_name = hs.get_instance_name()
- self.config = hs.config
+ self._config = hs.config
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
@@ -171,11 +173,14 @@ class FederationEventHandler(BaseHandler):
pdu: received PDU
"""
+ # We should never see any outliers here.
+ assert not pdu.internal_metadata.outlier
+
room_id = pdu.room_id
event_id = pdu.event_id
# We reprocess pdus when we have seen them only as outliers
- existing = await self.store.get_event(
+ existing = await self._store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@@ -221,7 +226,7 @@ class FederationEventHandler(BaseHandler):
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
is_in_room = await self._event_auth_handler.check_host_in_room(
- room_id, self.server_name
+ room_id, self._server_name
)
if not is_in_room:
logger.info(
@@ -230,77 +235,71 @@ class FederationEventHandler(BaseHandler):
)
return None
- # Check that the event passes auth based on the state at the event. This is
- # done for events that are to be added to the timeline (non-outliers).
- #
- # Get missing pdus if necessary:
- # - Fetching any missing prev events to fill in gaps in the graph
- # - Fetching state if we have a hole in the graph
- if not pdu.internal_metadata.is_outlier():
- prevs = set(pdu.prev_event_ids())
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
+ # Try to fetch any missing prev events to fill in gaps in the graph
+ prevs = set(pdu.prev_event_ids())
+ seen = await self._store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
- if missing_prevs:
- # We only backfill backwards to the min depth.
- min_depth = await self.get_min_depth_for_context(pdu.room_id)
- logger.debug("min_depth: %d", min_depth)
+ if missing_prevs:
+ # We only backfill backwards to the min depth.
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
+ logger.debug("min_depth: %d", min_depth)
- if min_depth is not None and pdu.depth > min_depth:
- # If we're missing stuff, ensure we only fetch stuff one
- # at a time.
+ if min_depth is not None and pdu.depth > min_depth:
+ # If we're missing stuff, ensure we only fetch stuff one
+ # at a time.
+ logger.info(
+ "Acquiring room lock to fetch %d missing prev_events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
- "Acquiring room lock to fetch %d missing prev_events: %s",
+ "Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
- shortstr(missing_prevs),
)
- with (await self._room_pdu_linearizer.queue(pdu.room_id)):
- logger.info(
- "Acquired room lock to fetch %d missing prev_events",
- len(missing_prevs),
+
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
)
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ ) from e
- try:
- await self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
- except Exception as e:
- raise Exception(
- "Error fetching missing prev_events for %s: %s"
- % (event_id, e)
- ) from e
-
- # Update the set of things we've seen after trying to
- # fetch the missing stuff
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
-
- if not missing_prevs:
- logger.info("Found all missing prev_events")
-
- if missing_prevs:
- # since this event was pushed to us, it is possible for it to
- # become the only forward-extremity in the room, and we would then
- # trust its state to be the state for the whole room. This is very
- # bad. Further, if the event was pushed to us, there is no excuse
- # for us not to have all the prev_events. (XXX: apart from
- # min_depth?)
- #
- # We therefore reject any such events.
- logger.warning(
- "Rejecting: failed to fetch %d prev events: %s",
- len(missing_prevs),
- shortstr(missing_prevs),
- )
- raise FederationError(
- "ERROR",
- 403,
- (
- "Your server isn't divulging details about prev_events "
- "referenced in this event."
- ),
- affected=pdu.event_id,
- )
+ # Update the set of things we've seen after trying to
+ # fetch the missing stuff
+ seen = await self._store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ logger.info("Found all missing prev_events")
+
+ if missing_prevs:
+ # since this event was pushed to us, it is possible for it to
+ # become the only forward-extremity in the room, and we would then
+ # trust its state to be the state for the whole room. This is very
+ # bad. Further, if the event was pushed to us, there is no excuse
+ # for us not to have all the prev_events. (XXX: apart from
+ # min_depth?)
+ #
+ # We therefore reject any such events.
+ logger.warning(
+ "Rejecting: failed to fetch %d prev events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ (
+ "Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ affected=pdu.event_id,
+ )
await self._process_received_pdu(origin, pdu, state=None)
@@ -361,7 +360,7 @@ class FederationEventHandler(BaseHandler):
# the room, so we send it on their behalf.
event.internal_metadata.send_on_behalf_of = origin
- context = await self.state_handler.compute_event_context(event)
+ context = await self._state_handler.compute_event_context(event)
context = await self._check_event_auth(origin, event, context)
if context.rejected:
raise SynapseError(
@@ -375,7 +374,7 @@ class FederationEventHandler(BaseHandler):
# for knock events, we run the third-party event rules. It's not entirely clear
# why we don't do this for other sorts of membership events.
if event.membership == Membership.KNOCK:
- event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
+ event_allowed, _ = await self._third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@@ -404,7 +403,7 @@ class FederationEventHandler(BaseHandler):
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
prev_member_event = None
if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
+ prev_member_event = await self._store.get_event(prev_member_event_id)
# Check if the member should be allowed access via membership in a space.
await self._event_auth_handler.check_restricted_join_rules(
@@ -434,10 +433,10 @@ class FederationEventHandler(BaseHandler):
server from invalid events (there is probably no point in trying to
re-fetch invalid events from every other HS in the room.)
"""
- if dest == self.server_name:
+ if dest == self._server_name:
raise SynapseError(400, "Can't backfill from self.")
- events = await self.federation_client.backfill(
+ events = await self._federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -469,12 +468,12 @@ class FederationEventHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- seen = await self.store.have_events_in_timeline(prevs)
+ seen = await self._store.have_events_in_timeline(prevs)
if not prevs - seen:
return
- latest_list = await self.store.get_latest_event_ids_in_room(room_id)
+ latest_list = await self._store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -536,7 +535,7 @@ class FederationEventHandler(BaseHandler):
# All that said: Let's try increasing the timeout to 60s and see what happens.
try:
- missing_events = await self.federation_client.get_missing_events(
+ missing_events = await self._federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@@ -609,7 +608,7 @@ class FederationEventHandler(BaseHandler):
event_id = event.event_id
- existing = await self.store.get_event(
+ existing = await self._store.get_event(
event_id, allow_none=True, allow_rejected=True
)
if existing:
@@ -674,7 +673,7 @@ class FederationEventHandler(BaseHandler):
event_id = event.event_id
prevs = set(event.prev_event_ids())
- seen = await self.store.have_events_in_timeline(prevs)
+ seen = await self._store.have_events_in_timeline(prevs)
missing_prevs = prevs - seen
if not missing_prevs:
@@ -691,7 +690,7 @@ class FederationEventHandler(BaseHandler):
event_map = {event_id: event}
try:
# Get the state of the events we know about
- ours = await self.state_store.get_state_groups_ids(room_id, seen)
+ ours = await self._state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
@@ -720,13 +719,13 @@ class FederationEventHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- room_version = await self.store.get_room_version_id(room_id)
+ room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
event_map,
- state_res_store=StateResolutionStore(self.store),
+ state_res_store=StateResolutionStore(self._store),
)
# We need to give _process_received_pdu the actual state events
@@ -734,7 +733,7 @@ class FederationEventHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
- evs = await self.store.get_events(
+ evs = await self._store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -774,7 +773,7 @@ class FederationEventHandler(BaseHandler):
(
state_event_ids,
auth_event_ids,
- ) = await self.federation_client.get_room_state_ids(
+ ) = await self._federation_client.get_room_state_ids(
destination, room_id, event_id=event_id
)
@@ -788,7 +787,7 @@ class FederationEventHandler(BaseHandler):
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
- fetched_events = await self.store.get_events(
+ fetched_events = await self._store.get_events(
desired_events, allow_rejected=True
)
@@ -809,20 +808,20 @@ class FederationEventHandler(BaseHandler):
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
missing_auth_events.difference_update(
- await self.store.have_seen_events(room_id, missing_auth_events)
+ await self._store.have_seen_events(room_id, missing_auth_events)
)
logger.debug("We are also missing %i auth events", len(missing_auth_events))
missing_events = missing_desired_events | missing_auth_events
logger.debug("Fetching %i events from remote", len(missing_events))
await self._get_events_and_persist(
- destination=destination, room_id=room_id, events=missing_events
+ destination=destination, room_id=room_id, event_ids=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
- await self.store.get_events(missing_desired_events, allow_rejected=True)
+ await self._store.get_events(missing_desired_events, allow_rejected=True)
)
# check for events which were in the wrong room.
@@ -883,8 +882,13 @@ class FederationEventHandler(BaseHandler):
state: Optional[Iterable[EventBase]],
backfilled: bool = False,
) -> None:
- """Called when we have a new pdu. We need to do auth checks and put it
- through the StateHandler.
+ """Called when we have a new non-outlier event.
+
+ This is called when we have a new event to add to the room DAG - either directly
+ via a /send request, retrieved via get_missing_events after a /send request, or
+ backfilled after a client request.
+
+ We need to do auth checks and put it through the StateHandler.
Args:
origin: server sending the event
@@ -899,17 +903,24 @@ class FederationEventHandler(BaseHandler):
notification to clients, and validation of device keys.)
"""
logger.debug("Processing event: %s", event)
+ assert not event.internal_metadata.outlier
try:
- context = await self.state_handler.compute_event_context(
+ context = await self._state_handler.compute_event_context(
event, old_state=state
)
- await self._auth_and_persist_event(
- origin, event, context, state=state, backfilled=backfilled
+ context = await self._check_event_auth(
+ origin,
+ event,
+ context,
+ state=state,
+ backfilled=backfilled,
)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
+ await self._run_push_actions_and_persist_event(event, context, backfilled)
+
if backfilled:
return
@@ -919,7 +930,7 @@ class FederationEventHandler(BaseHandler):
device_id = event.content.get("device_id")
sender_key = event.content.get("sender_key")
- cached_devices = await self.store.get_cached_devices_for_user(event.sender)
+ cached_devices = await self._store.get_cached_devices_for_user(event.sender)
resync = False # Whether we should resync device lists.
@@ -995,10 +1006,10 @@ class FederationEventHandler(BaseHandler):
"""
try:
- await self.store.mark_remote_user_device_cache_as_stale(sender)
+ await self._store.mark_remote_user_device_cache_as_stale(sender)
# Immediately attempt a resync in the background
- if self.config.worker_app:
+ if self._config.worker_app:
await self._user_device_resync(user_id=sender)
else:
await self._device_list_updater.user_device_resync(sender)
@@ -1023,9 +1034,15 @@ class FederationEventHandler(BaseHandler):
return
# Skip processing a marker event if the room version doesn't
- # support it.
- room_version = await self.store.get_room_version(marker_event.room_id)
- if not room_version.msc2716_historical:
+ # support it or the event is not from the room creator.
+ room_version = await self._store.get_room_version(marker_event.room_id)
+ create_event = await self._store.get_create_event_for_room(marker_event.room_id)
+ room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
+ if (
+ not room_version.msc2716_historical
+ or not self._config.experimental.msc2716_enabled
+ or marker_event.sender != room_creator
+ ):
return
logger.debug("_handle_marker_event: received %s", marker_event)
@@ -1048,7 +1065,7 @@ class FederationEventHandler(BaseHandler):
[insertion_event_id],
)
- insertion_event = await self.store.get_event(
+ insertion_event = await self._store.get_event(
insertion_event_id, allow_none=True
)
if insertion_event is None:
@@ -1066,7 +1083,7 @@ class FederationEventHandler(BaseHandler):
marker_event,
)
- await self.store.insert_insertion_extremity(
+ await self._store.insert_insertion_extremity(
insertion_event_id, marker_event.room_id
)
@@ -1077,25 +1094,25 @@ class FederationEventHandler(BaseHandler):
)
async def _get_events_and_persist(
- self, destination: str, room_id: str, events: Iterable[str]
+ self, destination: str, room_id: str, event_ids: Collection[str]
) -> None:
"""Fetch the given events from a server, and persist them as outliers.
This function *does not* recursively get missing auth events of the
- newly fetched events. Callers must include in the `events` argument
+ newly fetched events. Callers must include in the `event_ids` argument
any missing events from the auth chain.
Logs a warning if we can't find the given event.
"""
- room_version = await self.store.get_room_version(room_id)
+ room_version = await self._store.get_room_version(room_id)
event_map: Dict[str, EventBase] = {}
async def get_event(event_id: str):
with nested_logging_context(event_id):
try:
- event = await self.federation_client.get_pdu(
+ event = await self._federation_client.get_pdu(
[destination],
event_id,
room_version,
@@ -1119,28 +1136,78 @@ class FederationEventHandler(BaseHandler):
e,
)
- await concurrently_execute(get_event, events, 5)
+ await concurrently_execute(get_event, event_ids, 5)
+ logger.info("Fetched %i events of %i requested", len(event_map), len(event_ids))
- # Make a map of auth events for each event. We do this after fetching
- # all the events as some of the events' auth events will be in the list
- # of requested events.
+ # we now need to auth the events in an order which ensures that each event's
+ # auth_events are authed before the event itself.
+ #
+ # XXX: it might be possible to kick this process off in parallel with fetching
+ # the events.
+ while event_map:
+ # build a list of events whose auth events are not in the queue.
+ roots = tuple(
+ ev
+ for ev in event_map.values()
+ if not any(aid in event_map for aid in ev.auth_event_ids())
+ )
- auth_events = [
- aid
- for event in event_map.values()
- for aid in event.auth_event_ids()
- if aid not in event_map
- ]
- persisted_events = await self.store.get_events(
+ if not roots:
+ # if *none* of the remaining events are ready, that means
+ # we have a loop. This either means a bug in our logic, or that
+ # somebody has managed to create a loop (which requires finding a
+ # hash collision in room v2 and later).
+ logger.warning(
+ "Loop found in auth events while fetching missing state/auth "
+ "events: %s",
+ shortstr(event_map.keys()),
+ )
+ return
+
+ logger.info(
+ "Persisting %i of %i remaining events", len(roots), len(event_map)
+ )
+
+ await self._auth_and_persist_fetched_events(destination, room_id, roots)
+
+ for ev in roots:
+ del event_map[ev.event_id]
+
+ async def _auth_and_persist_fetched_events(
+ self, origin: str, room_id: str, fetched_events: Collection[EventBase]
+ ) -> None:
+ """Persist the events fetched by _get_events_and_persist.
+
+ The events should not depend on one another, e.g. this should be used to persist
+ a bunch of outliers, but not a chunk of individual events that depend
+ on each other for state calculations.
+
+ We also assume that all of the auth events for all of the events have already
+ been persisted.
+
+ Notifies about the events where appropriate.
+
+ Params:
+ origin: where the events came from
+ room_id: the room that the events are meant to be in (though this has
+ not yet been checked)
+ event_id: map from event_id -> event for the fetched events
+ """
+ # get all the auth events for all the events in this batch. By now, they should
+ # have been persisted.
+ auth_events = {
+ aid for event in fetched_events for aid in event.auth_event_ids()
+ }
+ persisted_events = await self._store.get_events(
auth_events,
allow_rejected=True,
)
event_infos = []
- for event in event_map.values():
+ for event in fetched_events:
auth = {}
for auth_event_id in event.auth_event_ids():
- ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
+ ae = persisted_events.get(auth_event_id)
if ae:
auth[(ae.type, ae.state_key)] = ae
else:
@@ -1148,34 +1215,13 @@ class FederationEventHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, auth))
- if event_infos:
- await self._auth_and_persist_events(
- destination,
- room_id,
- event_infos,
- )
-
- async def _auth_and_persist_events(
- self,
- origin: str,
- room_id: str,
- event_infos: Collection[_NewEventInfo],
- ) -> None:
- """Creates the appropriate contexts and persists events. The events
- should not depend on one another, e.g. this should be used to persist
- a bunch of outliers, but not a chunk of individual events that depend
- on each other for state calculations.
-
- Notifies about the events where appropriate.
- """
-
if not event_infos:
return
async def prep(ev_info: _NewEventInfo):
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
- res = await self.state_handler.compute_event_context(event)
+ res = await self._state_handler.compute_event_context(event)
res = await self._check_event_auth(
origin,
event,
@@ -1199,49 +1245,6 @@ class FederationEventHandler(BaseHandler):
],
)
- async def _auth_and_persist_event(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
- backfilled: bool = False,
- ) -> None:
- """
- Process an event by performing auth checks and then persisting to the database.
-
- Args:
- origin: The host the event originates from.
- event: The event itself.
- context:
- The event context.
-
- state:
- The state events used to check the event for soft-fail. If this is
- not provided the current state events will be used.
-
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
-
- Only populated where we may not already have persisted these events -
- for example, when populating outliers.
-
- backfilled: True if the event was backfilled.
- """
- context = await self._check_event_auth(
- origin,
- event,
- context,
- state=state,
- claimed_auth_event_map=claimed_auth_event_map,
- backfilled=backfilled,
- )
-
- await self._run_push_actions_and_persist_event(event, context, backfilled)
-
async def _check_event_auth(
self,
origin: str,
@@ -1269,16 +1272,17 @@ class FederationEventHandler(BaseHandler):
Possibly incomplete, and possibly including events that are not yet
persisted, or authed, or in the right room.
- Only populated where we may not already have persisted these events -
- for example, when populating outliers, or the state for a backwards
- extremity.
+ Only populated when populating outliers.
backfilled: True if the event was backfilled.
Returns:
The updated context object.
"""
- room_version = await self.store.get_room_version_id(event.room_id)
+ # claimed_auth_event_map should be given iff the event is an outlier
+ assert bool(claimed_auth_event_map) == event.internal_metadata.outlier
+
+ room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if claimed_auth_event_map:
@@ -1291,7 +1295,7 @@ class FederationEventHandler(BaseHandler):
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events_x = await self.store.get_events(auth_events_ids)
+ auth_events_x = await self._store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
try:
@@ -1321,19 +1325,29 @@ class FederationEventHandler(BaseHandler):
if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled, origin=origin)
-
- if event.type == EventTypes.GuestAccess and not context.rejected:
- await self.maybe_kick_guest_users(event)
+ await self._maybe_kick_guest_users(event)
# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
- await self.event_creation_handler.cache_joined_hosts_for_event(
+ await self._event_creation_handler.cache_joined_hosts_for_event(
event, context
)
return context
+ async def _maybe_kick_guest_users(self, event: EventBase) -> None:
+ if event.type != EventTypes.GuestAccess:
+ return
+
+ guest_access = event.content.get(EventContentFields.GUEST_ACCESS)
+ if guest_access == GuestAccess.CAN_JOIN:
+ return
+
+ current_state_map = await self._state_handler.get_current_state(event.room_id)
+ current_state = list(current_state_map.values())
+ await self._get_room_member_handler().kick_guest_users(current_state)
+
async def _check_for_soft_fail(
self,
event: EventBase,
@@ -1356,7 +1370,7 @@ class FederationEventHandler(BaseHandler):
if backfilled or event.internal_metadata.is_outlier():
return
- extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())
@@ -1365,7 +1379,7 @@ class FederationEventHandler(BaseHandler):
# state at the event, so no point rechecking auth for soft fail.
return
- room_version = await self.store.get_room_version_id(event.room_id)
+ room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# Calculate the "current state".
@@ -1382,19 +1396,19 @@ class FederationEventHandler(BaseHandler):
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets_d = await self.state_store.get_state_groups(
+ state_sets_d = await self._state_store.get_state_groups(
event.room_id, extrem_ids
)
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets.append(state)
- current_states = await self.state_handler.resolve_events(
+ current_states = await self._state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
}
else:
- current_state_ids = await self.state_handler.get_current_state_ids(
+ current_state_ids = await self._state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
)
@@ -1410,7 +1424,7 @@ class FederationEventHandler(BaseHandler):
e for k, e in current_state_ids.items() if k in auth_types
]
- auth_events_map = await self.store.get_events(current_state_ids_list)
+ auth_events_map = await self._store.get_events(current_state_ids_list)
current_auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values()
}
@@ -1481,7 +1495,9 @@ class FederationEventHandler(BaseHandler):
#
# we start by checking if they are in the store, and then try calling /event_auth/.
if missing_auth:
- have_events = await self.store.have_seen_events(event.room_id, missing_auth)
+ have_events = await self._store.have_seen_events(
+ event.room_id, missing_auth
+ )
logger.debug("Events %s are in the store", have_events)
missing_auth.difference_update(have_events)
@@ -1490,7 +1506,7 @@ class FederationEventHandler(BaseHandler):
logger.info("auth_events contains unknown events: %s", missing_auth)
try:
try:
- remote_auth_chain = await self.federation_client.get_event_auth(
+ remote_auth_chain = await self._federation_client.get_event_auth(
origin, event.room_id, event.event_id
)
except RequestSendFailed as e1:
@@ -1499,43 +1515,49 @@ class FederationEventHandler(BaseHandler):
logger.info("Failed to get event auth from remote: %s", e1)
return context, auth_events
- seen_remotes = await self.store.have_seen_events(
+ seen_remotes = await self._store.have_seen_events(
event.room_id, [e.event_id for e in remote_auth_chain]
)
- for e in remote_auth_chain:
- if e.event_id in seen_remotes:
+ for auth_event in remote_auth_chain:
+ if auth_event.event_id in seen_remotes:
continue
- if e.event_id == event.event_id:
+ if auth_event.event_id == event.event_id:
continue
try:
- auth_ids = e.auth_event_ids()
+ auth_ids = auth_event.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
- e.internal_metadata.outlier = True
+ auth_event.internal_metadata.outlier = True
logger.debug(
"_check_event_auth %s missing_auth: %s",
event.event_id,
- e.event_id,
+ auth_event.event_id,
)
missing_auth_event_context = (
- await self.state_handler.compute_event_context(e)
+ await self._state_handler.compute_event_context(auth_event)
)
- await self._auth_and_persist_event(
+
+ missing_auth_event_context = await self._check_event_auth(
origin,
- e,
+ auth_event,
missing_auth_event_context,
claimed_auth_event_map=auth,
)
+ await self.persist_events_and_notify(
+ event.room_id, [(auth_event, missing_auth_event_context)]
+ )
- if e.event_id in event_auth_events:
- auth_events[(e.type, e.state_key)] = e
+ if auth_event.event_id in event_auth_events:
+ auth_events[
+ (auth_event.type, auth_event.state_key)
+ ] = auth_event
except AuthError:
pass
@@ -1566,7 +1588,7 @@ class FederationEventHandler(BaseHandler):
# XXX: currently this checks for redactions but I'm not convinced that is
# necessary?
- different_events = await self.store.get_events_as_list(different_auth)
+ different_events = await self._store.get_events_as_list(different_auth)
for d in different_events:
if d.room_id != event.room_id:
@@ -1592,8 +1614,8 @@ class FederationEventHandler(BaseHandler):
remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
remote_state = remote_auth_events.values()
- room_version = await self.store.get_room_version_id(event.room_id)
- new_state = await self.state_handler.resolve_events(
+ room_version = await self._store.get_room_version_id(event.room_id)
+ new_state = await self._state_handler.resolve_events(
room_version, (local_state, remote_state), event
)
@@ -1651,7 +1673,7 @@ class FederationEventHandler(BaseHandler):
# create a new state group as a delta from the existing one.
prev_group = context.state_group
- state_group = await self.state_store.store_state_group(
+ state_group = await self._state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@@ -1678,14 +1700,17 @@ class FederationEventHandler(BaseHandler):
context: The event context.
backfilled: True if the event was backfilled.
"""
+ # this method should not be called on outliers (those code paths call
+ # persist_events_and_notify directly.)
+ assert not event.internal_metadata.outlier
+
try:
if (
- not event.internal_metadata.is_outlier()
- and not backfilled
+ not backfilled
and not context.rejected
- and (await self.store.get_min_depth(event.room_id)) <= event.depth
+ and (await self._store.get_min_depth(event.room_id)) <= event.depth
):
- await self.action_generator.handle_push_actions_for_event(
+ await self._action_generator.handle_push_actions_for_event(
event, context
)
@@ -1694,7 +1719,7 @@ class FederationEventHandler(BaseHandler):
)
except Exception:
run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
+ self._store.remove_push_actions_from_staging, event.event_id
)
raise
@@ -1719,27 +1744,27 @@ class FederationEventHandler(BaseHandler):
The stream ID after which all events have been persisted.
"""
if not event_and_contexts:
- return self.store.get_current_events_token()
+ return self._store.get_current_events_token()
- instance = self.config.worker.events_shard_config.get_instance(room_id)
+ instance = self._config.worker.events_shard_config.get_instance(room_id)
if instance != self._instance_name:
# Limit the number of events sent over replication. We choose 200
# here as that is what we default to in `max_request_body_size(..)`
for batch in batch_iter(event_and_contexts, 200):
result = await self._send_events(
instance_name=instance,
- store=self.store,
+ store=self._store,
room_id=room_id,
event_and_contexts=batch,
backfilled=backfilled,
)
return result["max_stream_id"]
else:
- assert self.storage.persistence
+ assert self._storage.persistence
# Note that this returns the events that were persisted, which may not be
# the same as were passed in if some were deduplicated due to transaction IDs.
- events, max_stream_token = await self.storage.persistence.persist_events(
+ events, max_stream_token = await self._storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
@@ -1773,7 +1798,7 @@ class FederationEventHandler(BaseHandler):
# users
if event.internal_metadata.is_outlier():
if event.membership != Membership.INVITE:
- if not self.is_mine_id(target_user_id):
+ if not self._is_mine_id(target_user_id):
return
target_user = UserID.from_string(target_user_id)
@@ -1787,7 +1812,7 @@ class FederationEventHandler(BaseHandler):
event_pos = PersistedEventPosition(
self._instance_name, event.internal_metadata.stream_ordering
)
- self.notifier.on_new_room_event(
+ self._notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
)
@@ -1822,4 +1847,4 @@ class FederationEventHandler(BaseHandler):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def get_min_depth_for_context(self, context: str) -> int:
- return await self.store.get_min_depth(context)
+ return await self._store.get_min_depth(context)
|