diff options
author | Shay <hillerys@element.io> | 2022-05-20 01:54:12 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-20 09:54:12 +0100 |
commit | 71e8afe34d2103c5ccc9f2d1c99587d14b2acc56 (patch) | |
tree | 38ca83b911323cd5312165467e2b8bf077987d2a /synapse/events | |
parent | Fix `RetryDestinationLimiter` re-starting finished log contexts (#12803) (diff) | |
download | synapse-71e8afe34d2103c5ccc9f2d1c99587d14b2acc56.tar.xz |
Update EventContext `get_current_event_ids` and `get_prev_event_ids` to accept state filters and update calls where possible (#12791)
Diffstat (limited to 'synapse/events')
-rw-r--r-- | synapse/events/snapshot.py | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 9ccd24b298..7a91544119 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -24,6 +24,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage import Storage from synapse.storage.databases.main import DataStore + from synapse.storage.state import StateFilter @attr.s(slots=True, auto_attribs=True) @@ -196,7 +197,9 @@ class EventContext: return self._state_group - async def get_current_state_ids(self) -> Optional[StateMap[str]]: + async def get_current_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> Optional[StateMap[str]]: """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -204,6 +207,9 @@ class EventContext: not make it into the room state. This method will raise an exception if ``rejected`` is set. + Arg: + state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules + Returns: Returns None if state_group is None, which happens when the associated event is an outlier. @@ -216,7 +222,7 @@ class EventContext: assert self._state_delta_due_to_event is not None - prev_state_ids = await self.get_prev_state_ids() + prev_state_ids = await self.get_prev_state_ids(state_filter) if self._state_delta_due_to_event: prev_state_ids = dict(prev_state_ids) @@ -224,12 +230,17 @@ class EventContext: return prev_state_ids - async def get_prev_state_ids(self) -> StateMap[str]: + async def get_prev_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> StateMap[str]: """ Gets the room state map, excluding this event. For a non-state event, this will be the same as get_current_state_ids(). + Args: + state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules + Returns: Returns {} if state_group is None, which happens when the associated event is an outlier. @@ -239,7 +250,7 @@ class EventContext: """ assert self.state_group_before_event is not None return await self._storage.state.get_state_ids_for_group( - self.state_group_before_event + self.state_group_before_event, state_filter ) |