diff --git a/changelog.d/6503.misc b/changelog.d/6503.misc
new file mode 100644
index 0000000000..e4e9a5a3d4
--- /dev/null
+++ b/changelog.d/6503.misc
@@ -0,0 +1 @@
+Move get_state methods into FederationHandler.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 73e1dda6a3..d396e6564f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -37,9 +37,9 @@ from synapse.api.room_versions import (
)
from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable
from synapse.logging.utils import log_function
-from synapse.util import batch_iter, unwrapFirstError
+from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -308,19 +308,12 @@ class FederationClient(FederationBase):
return signed_pdu
@defer.inlineCallbacks
- @log_function
- def get_state_for_room(self, destination, room_id, event_id):
- """Requests all of the room state at a given event from a remote homeserver.
-
- Args:
- destination (str): The remote homeserver to query for the state.
- room_id (str): The id of the room we're interested in.
- event_id (str): The id of the event we want the state at.
+ def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
+ """Calls the /state_ids endpoint to fetch the state at a particular point
+ in the room, and the auth events for the given event
Returns:
- Deferred[Tuple[List[EventBase], List[EventBase]]]:
- A list of events in the state, and a list of events in the auth chain
- for the given event.
+ Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids)
"""
result = yield self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id
@@ -329,74 +322,12 @@ class FederationClient(FederationBase):
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
- desired_events = set(state_event_ids + auth_event_ids)
- event_map = yield self.get_events_from_store_or_dest(
- destination, room_id, desired_events
- )
-
- failed_to_fetch = desired_events - event_map.keys()
- if failed_to_fetch:
- logger.warning(
- "Failed to fetch missing state/auth events for %s: %s",
- room_id,
- failed_to_fetch,
- )
-
- pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
- auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-
- auth_chain.sort(key=lambda e: e.depth)
-
- return pdus, auth_chain
-
- @defer.inlineCallbacks
- def get_events_from_store_or_dest(self, destination, room_id, event_ids):
- """Fetch events from a remote destination, checking if we already have them.
-
- Args:
- destination (str)
- room_id (str)
- event_ids (Iterable[str])
-
- Returns:
- Deferred[dict[str, EventBase]]: A deferred resolving to a map
- from event_id to event
- """
- fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
-
- missing_events = set(event_ids) - fetched_events.keys()
-
- if not missing_events:
- return fetched_events
-
- logger.debug(
- "Fetching unknown state/auth events %s for room %s",
- missing_events,
- event_ids,
- )
-
- room_version = yield self.store.get_room_version(room_id)
-
- # XXX 20 requests at once? really?
- for batch in batch_iter(missing_events, 20):
- deferreds = [
- run_in_background(
- self.get_pdu,
- destinations=[destination],
- event_id=e_id,
- room_version=room_version,
- )
- for e_id in batch
- ]
-
- res = yield make_deferred_yieldable(
- defer.DeferredList(deferreds, consumeErrors=True)
- )
- for success, result in res:
- if success and result:
- fetched_events[result.event_id] = result
+ if not isinstance(state_event_ids, list) or not isinstance(
+ auth_event_ids, list
+ ):
+ raise Exception("invalid response from /state_ids")
- return fetched_events
+ return state_event_ids, auth_event_ids
@defer.inlineCallbacks
@log_function
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..c0dcf9abf8 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -379,11 +379,9 @@ class FederationHandler(BaseHandler):
(
remote_state,
got_auth_chain,
- ) = yield self.federation_client.get_state_for_room(
- origin, room_id, p
- )
+ ) = yield self._get_state_for_room(origin, room_id, p)
- # we want the state *after* p; get_state_for_room returns the
+ # we want the state *after* p; _get_state_for_room returns the
# state *before* p.
remote_event = yield self.federation_client.get_pdu(
[origin], p, room_version, outlier=True
@@ -584,6 +582,97 @@ class FederationHandler(BaseHandler):
raise
@defer.inlineCallbacks
+ @log_function
+ def _get_state_for_room(self, destination, room_id, event_id):
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination (str): The remote homeserver to query for the state.
+ room_id (str): The id of the room we're interested in.
+ event_id (str): The id of the event we want the state at.
+
+ Returns:
+ Deferred[Tuple[List[EventBase], List[EventBase]]]:
+ A list of events in the state, and a list of events in the auth chain
+ for the given event.
+ """
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = yield self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ desired_events = set(state_event_ids + auth_event_ids)
+ event_map = yield self._get_events_from_store_or_dest(
+ destination, room_id, desired_events
+ )
+
+ failed_to_fetch = desired_events - event_map.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state/auth events for %s: %s",
+ room_id,
+ failed_to_fetch,
+ )
+
+ pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ return pdus, auth_chain
+
+ @defer.inlineCallbacks
+ def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+ """Fetch events from a remote destination, checking if we already have them.
+
+ Args:
+ destination (str)
+ room_id (str)
+ event_ids (Iterable[str])
+
+ Returns:
+ Deferred[dict[str, EventBase]]: A deferred resolving to a map
+ from event_id to event
+ """
+ fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+
+ missing_events = set(event_ids) - fetched_events.keys()
+
+ if not missing_events:
+ return fetched_events
+
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+ event_ids,
+ )
+
+ room_version = yield self.store.get_room_version(room_id)
+
+ # XXX 20 requests at once? really?
+ for batch in batch_iter(missing_events, 20):
+ deferreds = [
+ run_in_background(
+ self.federation_client.get_pdu,
+ destinations=[destination],
+ event_id=e_id,
+ room_version=room_version,
+ )
+ for e_id in batch
+ ]
+
+ res = yield make_deferred_yieldable(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
+ for success, result in res:
+ if success and result:
+ fetched_events[result.event_id] = result
+
+ return fetched_events
+
+ @defer.inlineCallbacks
def _process_received_pdu(self, origin, event, state, auth_chain):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@@ -723,7 +812,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self.federation_client.get_state_for_room(
+ state, auth = yield self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
|