summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/state.py35
1 files changed, 25 insertions, 10 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f96..31ccbf23dc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
@@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.storage.databases import Databases
+
 logger = logging.getLogger(__name__)
 
 # Used for generic functions below
@@ -330,10 +343,12 @@ class StateGroupStorage:
     """High level interface to fetching state for event.
     """
 
-    def __init__(self, hs, stores):
+    def __init__(self, hs: "HomeServer", stores: "Databases"):
         self.stores = stores
 
-    async def get_state_group_delta(self, state_group: int):
+    async def get_state_group_delta(
+        self, state_group: int
+    ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
@@ -341,8 +356,8 @@ class StateGroupStorage:
             state_group: The state group used to retrieve state deltas.
 
         Returns:
-            Tuple[Optional[int], Optional[StateMap[str]]]:
-                (prev_group, delta_ids)
+            A tuple of the previous group and a state map of the event IDs which
+            make up the delta between the old and new state groups.
         """
 
         return await self.stores.state.get_state_group_delta(state_group)
@@ -436,7 +451,7 @@ class StateGroupStorage:
 
     async def get_state_for_events(
         self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
 
@@ -472,7 +487,7 @@ class StateGroupStorage:
 
     async def get_state_ids_for_events(
         self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
         of the state events (as opposed to the events themselves)
@@ -500,7 +515,7 @@ class StateGroupStorage:
 
     async def get_state_for_event(
         self, event_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
 
@@ -516,7 +531,7 @@ class StateGroupStorage:
 
     async def get_state_ids_for_event(
         self, event_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event