summary refs log tree commit diff
path: root/synapse/storage/controllers/state.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-06-06 11:24:12 +0300
committerGitHub <noreply@github.com>2022-06-06 09:24:12 +0100
commite3163e2e11cf8bffa4cb3e58ac0b86a83eca314c (patch)
tree639271e4b4d4157b95047f801f6b22bf39a24f08 /synapse/storage/controllers/state.py
parentRemove groups code from synapse_port_db. (#12899) (diff)
downloadsynapse-e3163e2e11cf8bffa4cb3e58ac0b86a83eca314c.tar.xz
Reduce the amount of state we pull from the DB (#12811)
Diffstat (limited to 'synapse/storage/controllers/state.py')
-rw-r--r--synapse/storage/controllers/state.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 9952b00493..63a78ebc87 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -455,3 +455,30 @@ class StateStorageController:
         return await self.stores.main.get_partial_current_state_deltas(
             prev_stream_id, max_stream_id
         )
+
+    async def get_current_state(
+        self, room_id: str, state_filter: Optional[StateFilter] = None
+    ) -> StateMap[EventBase]:
+        """Same as `get_current_state_ids` but also fetches the events"""
+        state_map_ids = await self.get_current_state_ids(room_id, state_filter)
+
+        event_map = await self.stores.main.get_events(list(state_map_ids.values()))
+
+        state_map = {}
+        for key, event_id in state_map_ids.items():
+            event = event_map.get(event_id)
+            if event:
+                state_map[key] = event
+
+        return state_map
+
+    async def get_current_state_event(
+        self, room_id: str, event_type: str, state_key: str
+    ) -> Optional[EventBase]:
+        """Get the current state event for the given type/state_key."""
+
+        key = (event_type, state_key)
+        state_map = await self.get_current_state(
+            room_id, StateFilter.from_types((key,))
+        )
+        return state_map.get(key)