summary refs log tree commit diff
path: root/synapse/storage/databases/main/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/state.py')
-rw-r--r--synapse/storage/databases/main/state.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py

index 96e0378e50..5c6168e301 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter +from synapse.types import StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=100000, iterable=True) - def get_current_state_ids(self, room_id): + async def get_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. Args: - room_id (str) + room_id: The room to get the state IDs of. Returns: - deferred: dict of (type, state_key) -> event_id + The current state of the room. """ def _get_current_state_ids_txn(txn): @@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - def get_filtered_current_state_ids( + async def get_filtered_current_state_ids( self, room_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state @@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): from the database. Returns: - defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. + Map from type/state_key to event ID. """ where_clause, where_args = state_filter.make_sql_filter_clause() if not where_clause: # We delegate to the cached version - return self.get_current_state_ids(room_id) + return await self.get_current_state_ids(room_id) def _get_filtered_current_state_ids_txn(txn): results = {} @@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) @@ -260,8 +261,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return event.content.get("canonical_alias") @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: + return await self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -273,12 +274,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cached_method_name="_get_state_group_for_event", list_name="event_ids", num_args=1, - inlineCallbacks=True, ) - def _get_state_group_for_events(self, event_ids): + async def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids,