summary refs log tree commit diff
path: root/synapse/events/snapshot.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events/snapshot.py')
-rw-r--r--synapse/events/snapshot.py46
1 files changed, 22 insertions, 24 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f94cdcbaba..afecafe15c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,17 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 import attr
 from frozendict import frozendict
 
-from twisted.internet import defer
-
 from synapse.appservice import ApplicationService
+from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.types import StateMap
 
+if TYPE_CHECKING:
+    from synapse.storage.databases.main import DataStore
+
 
 @attr.s(slots=True)
 class EventContext:
@@ -129,8 +131,7 @@ class EventContext:
             delta_ids=delta_ids,
         )
 
-    @defer.inlineCallbacks
-    def serialize(self, event, store):
+    async def serialize(self, event: EventBase, store: "DataStore") -> dict:
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
 
@@ -146,7 +147,7 @@ class EventContext:
         # the prev_state_ids, so if we're a state event we include the event
         # id that we replaced in the state.
         if event.is_state():
-            prev_state_ids = yield self.get_prev_state_ids()
+            prev_state_ids = await self.get_prev_state_ids()
             prev_state_id = prev_state_ids.get((event.type, event.state_key))
         else:
             prev_state_id = None
@@ -214,8 +215,7 @@ class EventContext:
 
         return self._state_group
 
-    @defer.inlineCallbacks
-    def get_current_state_ids(self):
+    async def get_current_state_ids(self) -> Optional[StateMap[str]]:
         """
         Gets the room state map, including this event - ie, the state in ``state_group``
 
@@ -224,32 +224,31 @@ class EventContext:
         ``rejected`` is set.
 
         Returns:
-            Deferred[dict[(str, str), str]|None]: Returns None if state_group
-                is None, which happens when the associated event is an outlier.
+            Returns None if state_group is None, which happens when the associated
+            event is an outlier.
 
-                Maps a (type, state_key) to the event ID of the state event matching
-                this tuple.
+            Maps a (type, state_key) to the event ID of the state event matching
+            this tuple.
         """
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
-        yield self._ensure_fetched()
+        await self._ensure_fetched()
         return self._current_state_ids
 
-    @defer.inlineCallbacks
-    def get_prev_state_ids(self):
+    async def get_prev_state_ids(self):
         """
         Gets the room state map, excluding this event.
 
         For a non-state event, this will be the same as get_current_state_ids().
 
         Returns:
-            Deferred[dict[(str, str), str]|None]: Returns None if state_group
+            dict[(str, str), str]|None: Returns None if state_group
                 is None, which happens when the associated event is an outlier.
                 Maps a (type, state_key) to the event ID of the state event matching
                 this tuple.
         """
-        yield self._ensure_fetched()
+        await self._ensure_fetched()
         return self._prev_state_ids
 
     def get_cached_current_state_ids(self):
@@ -269,8 +268,8 @@ class EventContext:
 
         return self._current_state_ids
 
-    def _ensure_fetched(self):
-        return defer.succeed(None)
+    async def _ensure_fetched(self):
+        return None
 
 
 @attr.s(slots=True)
@@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
     _event_state_key = attr.ib(default=None)
     _fetching_state_deferred = attr.ib(default=None)
 
-    def _ensure_fetched(self):
+    async def _ensure_fetched(self):
         if not self._fetching_state_deferred:
             self._fetching_state_deferred = run_in_background(self._fill_out_state)
 
-        return make_deferred_yieldable(self._fetching_state_deferred)
+        return await make_deferred_yieldable(self._fetching_state_deferred)
 
-    @defer.inlineCallbacks
-    def _fill_out_state(self):
+    async def _fill_out_state(self):
         """Called to populate the _current_state_ids and _prev_state_ids
         attributes by loading from the database.
         """
         if self.state_group is None:
             return
 
-        self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
+        self._current_state_ids = await self._storage.state.get_state_ids_for_group(
             self.state_group
         )
         if self._event_state_key is not None: