diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f94cdcbaba..cca93e3a46 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
import attr
from frozendict import frozendict
-from twisted.internet import defer
-
from synapse.appservice import ApplicationService
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap
+if TYPE_CHECKING:
+ from synapse.storage.data_stores.main import DataStore
+
@attr.s(slots=True)
class EventContext:
@@ -129,8 +131,7 @@ class EventContext:
delta_ids=delta_ids,
)
- @defer.inlineCallbacks
- def serialize(self, event, store):
+ async def serialize(self, event: EventBase, store: "DataStore") -> dict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -146,7 +147,7 @@ class EventContext:
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
- prev_state_ids = yield self.get_prev_state_ids()
+ prev_state_ids = await self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
@@ -214,8 +215,7 @@ class EventContext:
return self._state_group
- @defer.inlineCallbacks
- def get_current_state_ids(self):
+ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``
@@ -224,32 +224,31 @@ class EventContext:
``rejected`` is set.
Returns:
- Deferred[dict[(str, str), str]|None]: Returns None if state_group
- is None, which happens when the associated event is an outlier.
+ Returns None if state_group is None, which happens when the associated
+ event is an outlier.
- Maps a (type, state_key) to the event ID of the state event matching
- this tuple.
+ Maps a (type, state_key) to the event ID of the state event matching
+ this tuple.
"""
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")
- yield self._ensure_fetched()
+ await self._ensure_fetched()
return self._current_state_ids
- @defer.inlineCallbacks
- def get_prev_state_ids(self):
+ async def get_prev_state_ids(self):
"""
Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids().
Returns:
- Deferred[dict[(str, str), str]|None]: Returns None if state_group
+ dict[(str, str), str]|None: Returns None if state_group
is None, which happens when the associated event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
- yield self._ensure_fetched()
+ await self._ensure_fetched()
return self._prev_state_ids
def get_cached_current_state_ids(self):
@@ -269,8 +268,8 @@ class EventContext:
return self._current_state_ids
- def _ensure_fetched(self):
- return defer.succeed(None)
+ async def _ensure_fetched(self):
+ return None
@attr.s(slots=True)
@@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
- def _ensure_fetched(self):
+ async def _ensure_fetched(self):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state)
- return make_deferred_yieldable(self._fetching_state_deferred)
+ return await make_deferred_yieldable(self._fetching_state_deferred)
- @defer.inlineCallbacks
- def _fill_out_state(self):
+ async def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return
- self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
+ self._current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group
)
if self._event_state_key is not None:
|