summary refs log tree commit diff
path: root/synapse/events/snapshot.py
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2022-05-20 01:54:12 -0700
committerGitHub <noreply@github.com>2022-05-20 09:54:12 +0100
commit71e8afe34d2103c5ccc9f2d1c99587d14b2acc56 (patch)
tree38ca83b911323cd5312165467e2b8bf077987d2a /synapse/events/snapshot.py
parentFix `RetryDestinationLimiter` re-starting finished log contexts (#12803) (diff)
downloadsynapse-71e8afe34d2103c5ccc9f2d1c99587d14b2acc56.tar.xz
Update EventContext `get_current_event_ids` and `get_prev_event_ids` to accept state filters and update calls where possible (#12791)
Diffstat (limited to 'synapse/events/snapshot.py')
-rw-r--r--synapse/events/snapshot.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py

index 9ccd24b298..7a91544119 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py
@@ -24,6 +24,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage import Storage from synapse.storage.databases.main import DataStore + from synapse.storage.state import StateFilter @attr.s(slots=True, auto_attribs=True) @@ -196,7 +197,9 @@ class EventContext: return self._state_group - async def get_current_state_ids(self) -> Optional[StateMap[str]]: + async def get_current_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> Optional[StateMap[str]]: """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -204,6 +207,9 @@ class EventContext: not make it into the room state. This method will raise an exception if ``rejected`` is set. + Arg: + state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules + Returns: Returns None if state_group is None, which happens when the associated event is an outlier. @@ -216,7 +222,7 @@ class EventContext: assert self._state_delta_due_to_event is not None - prev_state_ids = await self.get_prev_state_ids() + prev_state_ids = await self.get_prev_state_ids(state_filter) if self._state_delta_due_to_event: prev_state_ids = dict(prev_state_ids) @@ -224,12 +230,17 @@ class EventContext: return prev_state_ids - async def get_prev_state_ids(self) -> StateMap[str]: + async def get_prev_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> StateMap[str]: """ Gets the room state map, excluding this event. For a non-state event, this will be the same as get_current_state_ids(). + Args: + state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules + Returns: Returns {} if state_group is None, which happens when the associated event is an outlier. @@ -239,7 +250,7 @@ class EventContext: """ assert self.state_group_before_event is not None return await self._storage.state.get_state_ids_for_group( - self.state_group_before_event + self.state_group_before_event, state_filter )