diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 96e0378e50..5c6168e301 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
+from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=100000, iterable=True)
- def get_current_state_ids(self, room_id):
+ async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
Args:
- room_id (str)
+ room_id: The room to get the state IDs of.
Returns:
- deferred: dict of (type, state_key) -> event_id
+ The current state of the room.
"""
def _get_current_state_ids_txn(txn):
@@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
- def get_filtered_current_state_ids(
+ async def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
@@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
from the database.
Returns:
- defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+ Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()
if not where_clause:
# We delegate to the cached version
- return self.get_current_state_ids(room_id)
+ return await self.get_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(txn):
results = {}
@@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@@ -260,8 +261,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias")
@cached(max_entries=50000)
- def _get_state_group_for_event(self, event_id):
- return self.db_pool.simple_select_one_onecol(
+ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+ return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
@@ -273,12 +274,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
- inlineCallbacks=True,
)
- def _get_state_group_for_events(self, event_ids):
+ async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
|