diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index bb38a04ede..a360699408 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -16,12 +16,12 @@
import collections.abc
import logging
from collections import namedtuple
-
-from twisted.internet import defer
+from typing import Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = await self.get_create_event_for_room(room_id)
return create_event.content.get("room_version", "1")
- @defer.inlineCallbacks
- def get_room_predecessor(self, room_id):
+ async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Args:
- room_id (str)
+ room_id: The room ID.
Returns:
- Deferred[dict|None]: A dictionary containing the structure of the predecessor
- field from the room's create event. The structure is subject to other servers,
- but it is expected to be:
- * room_id (str): The room ID of the predecessor room
- * event_id (str): The ID of the tombstone event in the predecessor room
+ A dictionary containing the structure of the predecessor
+ field from the room's create event. The structure is subject to other servers,
+ but it is expected to be:
+ * room_id (str): The room ID of the predecessor room
+ * event_id (str): The ID of the tombstone event in the predecessor room
- None if a predecessor key is not found, or is not a dictionary.
+ None if a predecessor key is not found, or is not a dictionary.
Raises:
NotFoundError if the given room is unknown
"""
# Retrieve the room's create event
- create_event = yield self.get_create_event_for_room(room_id)
+ create_event = await self.get_create_event_for_room(room_id)
# Retrieve the predecessor key of the create event
predecessor = create_event.content.get("predecessor", None)
@@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return predecessor
- @defer.inlineCallbacks
- def get_create_event_for_room(self, room_id):
+ async def get_create_event_for_room(self, room_id: str) -> EventBase:
"""Get the create state event for a room.
Args:
- room_id (str)
+ room_id: The room ID.
Returns:
- Deferred[EventBase]: The room creation event.
+ The room creation event.
Raises:
NotFoundError if the room is unknown
"""
- state_ids = yield self.get_current_state_ids(room_id)
+ state_ids = await self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))
# If we can't find the create event, assume we've hit a dead end
@@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return
- create_event = yield self.get_event(create_id)
+ create_event = await self.get_event(create_id)
return create_event
@cached(max_entries=100000, iterable=True)
@@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
- @defer.inlineCallbacks
- def get_canonical_alias_for_room(self, room_id):
+ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
"""Get canonical alias for room, if any
Args:
- room_id (str)
+ room_id: The room ID
Returns:
- Deferred[str|None]: The canonical alias, if any
+ The canonical alias, if any
"""
- state = yield self.get_filtered_current_state_ids(
+ state = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)
@@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_id:
return
- event = yield self.get_event(event_id, allow_none=True)
+ event = await self.get_event(event_id, allow_none=True)
if not event:
return
@@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {row["event_id"]: row["state_group"] for row in rows}
- @defer.inlineCallbacks
- def get_referenced_state_groups(self, state_groups):
+ async def get_referenced_state_groups(
+ self, state_groups: Iterable[int]
+ ) -> Set[int]:
"""Check if the state groups are referenced by events.
Args:
- state_groups (Iterable[int])
+ state_groups
Returns:
- Deferred[set[int]]: The subset of state groups that are
- referenced.
+ The subset of state groups that are referenced.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
|