summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py13
-rw-r--r--synapse/state/__init__.py92
2 files changed, 55 insertions, 50 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0073e7c996..1a8144405a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,7 +21,7 @@ import itertools
 import logging
 from collections.abc import Container
 from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -69,7 +69,7 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
     ReplicationStoreRoomOnInviteRestServlet,
 )
-from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.state import StateResolutionStore
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     JsonDict,
@@ -85,6 +85,9 @@ from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import shortstr
 from synapse.visibility import filter_events_for_server
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -116,7 +119,7 @@ class FederationHandler(BaseHandler):
         rooms.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.hs = hs
@@ -126,6 +129,7 @@ class FederationHandler(BaseHandler):
         self.state_store = self.storage.state
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
+        self._state_resolution_handler = hs.get_state_resolution_handler()
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
         self.action_generator = hs.get_action_generator()
@@ -381,8 +385,7 @@ class FederationHandler(BaseHandler):
                                 event_map[x.event_id] = x
 
                     room_version = await self.store.get_room_version_id(room_id)
-                    state_map = await resolve_events_with_store(
-                        self.clock,
+                    state_map = await self._state_resolution_handler.resolve_events_with_store(
                         room_id,
                         room_version,
                         state_maps,
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 5a5ea39e01..98ede2ea4f 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -449,8 +449,7 @@ class StateHandler:
         state_map = {ev.event_id: ev for st in state_sets for ev in st}
 
         with Measure(self.clock, "state._resolve_events"):
-            new_state = await resolve_events_with_store(
-                self.clock,
+            new_state = await self._state_resolution_handler.resolve_events_with_store(
                 event.room_id,
                 room_version,
                 state_set_ids,
@@ -531,8 +530,7 @@ class StateResolutionHandler:
             state_groups_histogram.observe(len(state_groups_ids))
 
             with Measure(self.clock, "state._resolve_events"):
-                new_state = await resolve_events_with_store(
-                    self.clock,
+                new_state = await self.resolve_events_with_store(
                     room_id,
                     room_version,
                     list(state_groups_ids.values()),
@@ -552,6 +550,51 @@ class StateResolutionHandler:
 
             return cache
 
+    def resolve_events_with_store(
+        self,
+        room_id: str,
+        room_version: str,
+        state_sets: Sequence[StateMap[str]],
+        event_map: Optional[Dict[str, EventBase]],
+        state_res_store: "StateResolutionStore",
+    ) -> Awaitable[StateMap[str]]:
+        """
+        Args:
+            room_id: the room we are working in
+
+            room_version: Version of the room
+
+            state_sets: List of dicts of (type, state_key) -> event_id,
+                which are the different state groups to resolve.
+
+            event_map:
+                a dict from event_id to event, for any events that we happen to
+                have in flight (eg, those currently being persisted). This will be
+                used as a starting point fof finding the state we need; any missing
+                events will be requested via state_map_factory.
+
+                If None, all events will be fetched via state_res_store.
+
+            state_res_store: a place to fetch events from
+
+        Returns:
+            a map from (type, state_key) to event_id.
+        """
+        v = KNOWN_ROOM_VERSIONS[room_version]
+        if v.state_res == StateResolutionVersions.V1:
+            return v1.resolve_events_with_store(
+                room_id, state_sets, event_map, state_res_store.get_events
+            )
+        else:
+            return v2.resolve_events_with_store(
+                self.clock,
+                room_id,
+                room_version,
+                state_sets,
+                event_map,
+                state_res_store,
+            )
+
 
 def _make_state_cache_entry(
     new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
@@ -605,47 +648,6 @@ def _make_state_cache_entry(
     )
 
 
-def resolve_events_with_store(
-    clock: Clock,
-    room_id: str,
-    room_version: str,
-    state_sets: Sequence[StateMap[str]],
-    event_map: Optional[Dict[str, EventBase]],
-    state_res_store: "StateResolutionStore",
-) -> Awaitable[StateMap[str]]:
-    """
-    Args:
-        room_id: the room we are working in
-
-        room_version: Version of the room
-
-        state_sets: List of dicts of (type, state_key) -> event_id,
-            which are the different state groups to resolve.
-
-        event_map:
-            a dict from event_id to event, for any events that we happen to
-            have in flight (eg, those currently being persisted). This will be
-            used as a starting point fof finding the state we need; any missing
-            events will be requested via state_map_factory.
-
-            If None, all events will be fetched via state_res_store.
-
-        state_res_store: a place to fetch events from
-
-    Returns:
-        a map from (type, state_key) to event_id.
-    """
-    v = KNOWN_ROOM_VERSIONS[room_version]
-    if v.state_res == StateResolutionVersions.V1:
-        return v1.resolve_events_with_store(
-            room_id, state_sets, event_map, state_res_store.get_events
-        )
-    else:
-        return v2.resolve_events_with_store(
-            clock, room_id, room_version, state_sets, event_map, state_res_store
-        )
-
-
 @attr.s(slots=True)
 class StateResolutionStore:
     """Interface that allows state resolution algorithms to access the database