diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f96..31ccbf23dc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
import attr
@@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.storage.databases import Databases
+
logger = logging.getLogger(__name__)
# Used for generic functions below
@@ -330,10 +343,12 @@ class StateGroupStorage:
"""High level interface to fetching state for event.
"""
- def __init__(self, hs, stores):
+ def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
- async def get_state_group_delta(self, state_group: int):
+ async def get_state_group_delta(
+ self, state_group: int
+ ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -341,8 +356,8 @@ class StateGroupStorage:
state_group: The state group used to retrieve state deltas.
Returns:
- Tuple[Optional[int], Optional[StateMap[str]]]:
- (prev_group, delta_ids)
+ A tuple of the previous group and a state map of the event IDs which
+ make up the delta between the old and new state groups.
"""
return await self.stores.state.get_state_group_delta(state_group)
@@ -436,7 +451,7 @@ class StateGroupStorage:
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@@ -472,7 +487,7 @@ class StateGroupStorage:
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
@@ -500,7 +515,7 @@ class StateGroupStorage:
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
@@ -516,7 +531,7 @@ class StateGroupStorage:
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
|