diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 1d050e54e2..ac0f2ccfb3 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -34,7 +34,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
-from synapse.api.constants import EduTypes, EventTypes
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -46,6 +46,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
@@ -107,9 +108,9 @@ class FederationServer(FederationBase):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.auth = hs.get_auth()
self.handler = hs.get_federation_handler()
self.state = hs.get_state_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
self.device_handler = hs.get_device_handler()
@@ -147,6 +148,41 @@ class FederationServer(FederationBase):
self._room_prejoin_state_types = hs.config.api.room_prejoin_state
+ # Whether we have started handling old events in the staging area.
+ self._started_handling_of_staged_events = False
+
+ @wrap_as_background_process("_handle_old_staged_events")
+ async def _handle_old_staged_events(self) -> None:
+ """Handle old staged events by fetching all rooms that have staged
+ events and starting the processing of each of those rooms.
+ """
+
+ # Get all the room IDs with staged events.
+ room_ids = await self.store.get_all_rooms_with_staged_incoming_events()
+
+ # We then shuffle them so that if there are multiple instances doing
+ # this work they're less likely to collide.
+ random.shuffle(room_ids)
+
+ for room_id in room_ids:
+ room_version = await self.store.get_room_version(room_id)
+
+ # Try and acquire the processing lock for the room, if we get it start a
+ # background process for handling the events in the room.
+ lock = await self.store.try_acquire_lock(
+ _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
+ )
+ if lock:
+ logger.info("Handling old staged inbound events in %s", room_id)
+ self._process_incoming_pdus_in_room_inner(
+ room_id,
+ room_version,
+ lock,
+ )
+
+ # We pause a bit so that we don't start handling all rooms at once.
+ await self._clock.sleep(random.uniform(0, 0.1))
+
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
@@ -165,6 +201,12 @@ class FederationServer(FederationBase):
async def on_incoming_transaction(
self, origin: str, transaction_data: JsonDict
) -> Tuple[int, Dict[str, Any]]:
+ # If we receive a transaction we should make sure that we kick off
+ # handling of any old events in the staging area.
+ if not self._started_handling_of_staged_events:
+ self._started_handling_of_staged_events = True
+ self._handle_old_staged_events()
+
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -368,22 +410,21 @@ class FederationServer(FederationBase):
async def process_pdu(pdu: EventBase) -> JsonDict:
event_id = pdu.event_id
- with pdu_process_time.time():
- with nested_logging_context(event_id):
- try:
- await self._handle_received_pdu(origin, pdu)
- return {}
- except FederationError as e:
- logger.warning("Error handling PDU %s: %s", event_id, e)
- return {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- logger.error(
- "Failed to handle PDU %s",
- event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
- )
- return {"error": str(e)}
+ with nested_logging_context(event_id):
+ try:
+ await self._handle_received_pdu(origin, pdu)
+ return {}
+ except FederationError as e:
+ logger.warning("Error handling PDU %s: %s", event_id, e)
+ return {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ logger.error(
+ "Failed to handle PDU %s",
+ event_id,
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
+ )
+ return {"error": str(e)}
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
@@ -420,7 +461,7 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -453,7 +494,7 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -544,26 +585,21 @@ class FederationServer(FederationBase):
return {"event": ret_pdu.get_pdu_json(time_now)}
async def on_send_join_request(
- self, origin: str, content: JsonDict
+ self, origin: str, content: JsonDict, room_id: str
) -> Dict[str, Any]:
- logger.debug("on_send_join_request: content: %s", content)
-
- assert_params_in_dict(content, ["room_id"])
- room_version = await self.store.get_room_version(content["room_id"])
- pdu = event_from_pdu_json(content, room_version)
-
- origin_host, _ = parse_server_name(origin)
- await self.check_server_matches_acl(origin_host, pdu.room_id)
-
- logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
+ context = await self._on_send_membership_event(
+ origin, content, Membership.JOIN, room_id
+ )
- pdu = await self._check_sigs_and_hash(room_version, pdu)
+ prev_state_ids = await context.get_prev_state_ids()
+ state_ids = list(prev_state_ids.values())
+ auth_chain = await self.store.get_auth_chain(room_id, state_ids)
+ state = await self.store.get_events(state_ids)
- res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
return {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+ "state": [p.get_pdu_json(time_now) for p in state.values()],
+ "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
}
async def on_make_leave_request(
@@ -578,21 +614,11 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
+ async def on_send_leave_request(
+ self, origin: str, content: JsonDict, room_id: str
+ ) -> dict:
logger.debug("on_send_leave_request: content: %s", content)
-
- assert_params_in_dict(content, ["room_id"])
- room_version = await self.store.get_room_version(content["room_id"])
- pdu = event_from_pdu_json(content, room_version)
-
- origin_host, _ = parse_server_name(origin)
- await self.check_server_matches_acl(origin_host, pdu.room_id)
-
- logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
-
- pdu = await self._check_sigs_and_hash(room_version, pdu)
-
- await self.handler.on_send_leave_request(origin, pdu)
+ await self._on_send_membership_event(origin, content, Membership.LEAVE, room_id)
return {}
async def on_make_knock_request(
@@ -658,39 +684,76 @@ class FederationServer(FederationBase):
Returns:
The stripped room state.
"""
- logger.debug("on_send_knock_request: content: %s", content)
+ event_context = await self._on_send_membership_event(
+ origin, content, Membership.KNOCK, room_id
+ )
+
+ # Retrieve stripped state events from the room and send them back to the remote
+ # server. This will allow the remote server's clients to display information
+ # related to the room while the knock request is pending.
+ stripped_room_state = (
+ await self.store.get_stripped_room_state_from_event_context(
+ event_context, self._room_prejoin_state_types
+ )
+ )
+ return {"knock_state_events": stripped_room_state}
+
+ async def _on_send_membership_event(
+ self, origin: str, content: JsonDict, membership_type: str, room_id: str
+ ) -> EventContext:
+ """Handle an on_send_{join,leave,knock} request
+
+ Does some preliminary validation before passing the request on to the
+ federation handler.
+
+ Args:
+ origin: The (authenticated) requesting server
+ content: The body of the send_* request - a complete membership event
+ membership_type: The expected membership type (join, leave or knock,
+ depending on the endpoint)
+ room_id: The room_id from the request, to be validated against the room_id
+ in the event
+
+ Returns:
+ The context of the event after inserting it into the room graph.
+
+ Raises:
+ SynapseError if there is a problem with the request, including things like
+ the room_id not matching or the event not being authorized.
+ """
+ assert_params_in_dict(content, ["room_id"])
+ if content["room_id"] != room_id:
+ raise SynapseError(
+ 400,
+ "Room ID in body does not match that in request path",
+ Codes.BAD_JSON,
+ )
room_version = await self.store.get_room_version(room_id)
- # Check that this room supports knocking as defined by its room version
- if not room_version.msc2403_knocking:
+ if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
raise SynapseError(
403,
"This room version does not support knocking",
errcode=Codes.FORBIDDEN,
)
- pdu = event_from_pdu_json(content, room_version)
+ event = event_from_pdu_json(content, room_version)
- origin_host, _ = parse_server_name(origin)
- await self.check_server_matches_acl(origin_host, pdu.room_id)
+ if event.type != EventTypes.Member or not event.is_state():
+ raise SynapseError(400, "Not an m.room.member event", Codes.BAD_JSON)
- logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
+ if event.content.get("membership") != membership_type:
+ raise SynapseError(400, "Not a %s event" % membership_type, Codes.BAD_JSON)
- pdu = await self._check_sigs_and_hash(room_version, pdu)
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, event.room_id)
- # Handle the event, and retrieve the EventContext
- event_context = await self.handler.on_send_knock_request(origin, pdu)
+ logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
- # Retrieve stripped state events from the room and send them back to the remote
- # server. This will allow the remote server's clients to display information
- # related to the room while the knock request is pending.
- stripped_room_state = (
- await self.store.get_stripped_room_state_from_event_context(
- event_context, self._room_prejoin_state_types
- )
- )
- return {"knock_state_events": stripped_room_state}
+ event = await self._check_sigs_and_hash(room_version, event)
+
+ return await self.handler.on_send_membership_event(origin, event)
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
@@ -860,32 +923,39 @@ class FederationServer(FederationBase):
room_id: str,
room_version: RoomVersion,
lock: Lock,
- latest_origin: str,
- latest_event: EventBase,
+ latest_origin: Optional[str] = None,
+ latest_event: Optional[EventBase] = None,
) -> None:
"""Process events in the staging area for the given room.
The latest_origin and latest_event args are the latest origin and event
- received.
+ received (or None to simply pull the next event from the database).
"""
# The common path is for the event we just received be the only event in
# the room, so instead of pulling the event out of the DB and parsing
# the event we just pull out the next event ID and check if that matches.
- next_origin, next_event_id = await self.store.get_next_staged_event_id_for_room(
- room_id
- )
- if next_origin == latest_origin and next_event_id == latest_event.event_id:
- origin = latest_origin
- event = latest_event
- else:
+ if latest_event is not None and latest_origin is not None:
+ (
+ next_origin,
+ next_event_id,
+ ) = await self.store.get_next_staged_event_id_for_room(room_id)
+ if next_origin != latest_origin or next_event_id != latest_event.event_id:
+ latest_origin = None
+ latest_event = None
+
+ if latest_origin is None or latest_event is None:
next = await self.store.get_next_staged_event_for_room(
room_id, room_version
)
if not next:
+ await lock.release()
return
origin, event = next
+ else:
+ origin = latest_origin
+ event = latest_event
# We loop round until there are no more events in the room in the
# staging area, or we fail to get the lock (which means another process
@@ -909,9 +979,13 @@ class FederationServer(FederationBase):
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
- await self.store.remove_received_event_from_staging(
+ received_ts = await self.store.remove_received_event_from_staging(
origin, event.event_id
)
+ if received_ts is not None:
+ pdu_process_time.observe(
+ (self._clock.time_msec() - received_ts) / 1000
+ )
# We need to do this check outside the lock to avoid a race between
# a new event being inserted by another instance and it attempting
|