summary refs log tree commit diff
path: root/synapse/events/snapshot.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events/snapshot.py')
-rw-r--r--synapse/events/snapshot.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 9ccd24b298..7a91544119 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -24,6 +24,7 @@ from synapse.types import JsonDict, StateMap
 if TYPE_CHECKING:
     from synapse.storage import Storage
     from synapse.storage.databases.main import DataStore
+    from synapse.storage.state import StateFilter
 
 
 @attr.s(slots=True, auto_attribs=True)
@@ -196,7 +197,9 @@ class EventContext:
 
         return self._state_group
 
-    async def get_current_state_ids(self) -> Optional[StateMap[str]]:
+    async def get_current_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> Optional[StateMap[str]]:
         """
         Gets the room state map, including this event - ie, the state in ``state_group``
 
@@ -204,6 +207,9 @@ class EventContext:
         not make it into the room state. This method will raise an exception if
         ``rejected`` is set.
 
+        Arg:
+           state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
+
         Returns:
             Returns None if state_group is None, which happens when the associated
             event is an outlier.
@@ -216,7 +222,7 @@ class EventContext:
 
         assert self._state_delta_due_to_event is not None
 
-        prev_state_ids = await self.get_prev_state_ids()
+        prev_state_ids = await self.get_prev_state_ids(state_filter)
 
         if self._state_delta_due_to_event:
             prev_state_ids = dict(prev_state_ids)
@@ -224,12 +230,17 @@ class EventContext:
 
         return prev_state_ids
 
-    async def get_prev_state_ids(self) -> StateMap[str]:
+    async def get_prev_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> StateMap[str]:
         """
         Gets the room state map, excluding this event.
 
         For a non-state event, this will be the same as get_current_state_ids().
 
+        Args:
+            state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
+
         Returns:
             Returns {} if state_group is None, which happens when the associated
             event is an outlier.
@@ -239,7 +250,7 @@ class EventContext:
         """
         assert self.state_group_before_event is not None
         return await self._storage.state.get_state_ids_for_group(
-            self.state_group_before_event
+            self.state_group_before_event, state_filter
         )