diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 9ccd24b298..7a91544119 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -24,6 +24,7 @@ from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
+ from synapse.storage.state import StateFilter
@attr.s(slots=True, auto_attribs=True)
@@ -196,7 +197,9 @@ class EventContext:
return self._state_group
- async def get_current_state_ids(self) -> Optional[StateMap[str]]:
+ async def get_current_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``
@@ -204,6 +207,9 @@ class EventContext:
not make it into the room state. This method will raise an exception if
``rejected`` is set.
+ Arg:
+ state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
+
Returns:
Returns None if state_group is None, which happens when the associated
event is an outlier.
@@ -216,7 +222,7 @@ class EventContext:
assert self._state_delta_due_to_event is not None
- prev_state_ids = await self.get_prev_state_ids()
+ prev_state_ids = await self.get_prev_state_ids(state_filter)
if self._state_delta_due_to_event:
prev_state_ids = dict(prev_state_ids)
@@ -224,12 +230,17 @@ class EventContext:
return prev_state_ids
- async def get_prev_state_ids(self) -> StateMap[str]:
+ async def get_prev_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> StateMap[str]:
"""
Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids().
+ Args:
+ state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
+
Returns:
Returns {} if state_group is None, which happens when the associated
event is an outlier.
@@ -239,7 +250,7 @@ class EventContext:
"""
assert self.state_group_before_event is not None
return await self._storage.state.get_state_ids_for_group(
- self.state_group_before_event
+ self.state_group_before_event, state_filter
)
|