diff options
Diffstat (limited to 'synapse/events/snapshot.py')
-rw-r--r-- | synapse/events/snapshot.py | 46 |
1 files changed, 22 insertions, 24 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f94cdcbaba..cca93e3a46 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,17 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import attr from frozendict import frozendict -from twisted.internet import defer - from synapse.appservice import ApplicationService +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import StateMap +if TYPE_CHECKING: + from synapse.storage.data_stores.main import DataStore + @attr.s(slots=True) class EventContext: @@ -129,8 +131,7 @@ class EventContext: delta_ids=delta_ids, ) - @defer.inlineCallbacks - def serialize(self, event, store): + async def serialize(self, event: EventBase, store: "DataStore") -> dict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -146,7 +147,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids() + prev_state_ids = await self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -214,8 +215,7 @@ class EventContext: return self._state_group - @defer.inlineCallbacks - def get_current_state_ids(self): + async def get_current_state_ids(self) -> Optional[StateMap[str]]: """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -224,32 +224,31 @@ class EventContext: ``rejected`` is set. Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group - is None, which happens when the associated event is an outlier. + Returns None if state_group is None, which happens when the associated + event is an outlier. - Maps a (type, state_key) to the event ID of the state event matching - this tuple. + Maps a (type, state_key) to the event ID of the state event matching + this tuple. """ if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched() + await self._ensure_fetched() return self._current_state_ids - @defer.inlineCallbacks - def get_prev_state_ids(self): + async def get_prev_state_ids(self): """ Gets the room state map, excluding this event. For a non-state event, this will be the same as get_current_state_ids(). Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group + dict[(str, str), str]|None: Returns None if state_group is None, which happens when the associated event is an outlier. Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched() + await self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -269,8 +268,8 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self): - return defer.succeed(None) + async def _ensure_fetched(self): + return None @attr.s(slots=True) @@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext): _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self): + async def _ensure_fetched(self): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background(self._fill_out_state) - return make_deferred_yieldable(self._fetching_state_deferred) + return await make_deferred_yieldable(self._fetching_state_deferred) - @defer.inlineCallbacks - def _fill_out_state(self): + async def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self._current_state_ids = await self._storage.state.get_state_ids_for_group( self.state_group ) if self._event_state_key is not None: |