summary refs log tree commit diff
path: root/synapse/storage/databases/main/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/state.py')
-rw-r--r--synapse/storage/databases/main/state.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 458f169617..5c6168e301 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
+from synapse.types import StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return create_event
 
     @cached(max_entries=100000, iterable=True)
-    def get_current_state_ids(self, room_id):
+    async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
         """Get the current state event ids for a room based on the
         current_state_events table.
 
         Args:
-            room_id (str)
+            room_id: The room to get the state IDs of.
 
         Returns:
-            deferred: dict of (type, state_key) -> event_id
+            The current state of the room.
         """
 
         def _get_current_state_ids_txn(txn):
@@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_ids", _get_current_state_ids_txn
         )
 
     # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(
+    async def get_filtered_current_state_ids(
         self, room_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
         of doing a fresh state resolution as per state_handler.get_current_state
@@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 from the database.
 
         Returns:
-            defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+            Map from type/state_key to event ID.
         """
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
 
         if not where_clause:
             # We delegate to the cached version
-            return self.get_current_state_ids(room_id)
+            return await self.get_current_state_ids(room_id)
 
         def _get_filtered_current_state_ids_txn(txn):
             results = {}
@@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return results
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )