summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-05-20 12:32:56 +0100
committerErik Johnston <erik@matrix.org>2022-05-20 12:51:53 +0100
commitf69785e875872ce2b8a96c4e5b55ca17ef862438 (patch)
treea7ed4ff4d02eb29ff601ab21632fb693bf1536e9
parentUse new store.get_current_state_event (diff)
downloadsynapse-f69785e875872ce2b8a96c4e5b55ca17ef862438.tar.xz
Add helper methods to store
-rw-r--r--synapse/storage/databases/main/state.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index c81fadbe9e..b8a277ae31 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -218,6 +218,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             "get_current_state_ids", _get_current_state_ids_txn
         )
 
+    async def get_current_state(self, room_id: str) -> StateMap[EventBase]:
+        """Same as `get_current_state_ids` but also fetches the events"""
+        state_map_ids = await self.get_current_state_ids(room_id)
+
+        event_map = await self.get_events(list(state_map_ids.values()))
+
+        state_map = {}
+        for key, event_id in state_map_ids.items():
+            event = event_map.get(event_id)
+            if event:
+                state_map[key] = event
+
+        return state_map
+
     # FIXME: how should this be cached?
     async def get_filtered_current_state_ids(
         self, room_id: str, state_filter: Optional[StateFilter] = None
@@ -269,6 +283,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )
 
+    async def get_filtered_current_state(
+        self, room_id: str, state_filter: Optional[StateFilter] = None
+    ) -> StateMap[EventBase]:
+        """Same as `get_filtered_current_state_ids` but also fetches the events"""
+        state_map_ids = await self.get_filtered_current_state_ids(room_id, state_filter)
+
+        event_map = await self.get_events(list(state_map_ids.values()))
+
+        state_map = {}
+        for key, event_id in state_map_ids.items():
+            event = event_map.get(event_id)
+            if event:
+                state_map[key] = event
+
+        return state_map
+
     async def get_current_state_event(
         self, room_id: str, event_type: str, state_key: str
     ) -> Optional[EventBase]: