summary refs log tree commit diff
path: root/synapse/storage/databases/main/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/events.py')
-rw-r--r--synapse/storage/databases/main/events.py33
1 files changed, 14 insertions, 19 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1a68bf32cb..b90e6de2d5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,13 +17,11 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
 
 import attr
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
@@ -113,15 +111,14 @@ class PersistEventsStore:
             hs.config.worker.writers.events == hs.get_instance_name()
         ), "Can only instantiate EventsStore on master"
 
-    @defer.inlineCallbacks
-    def _persist_events_and_state_updates(
+    async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremeties: Dict[str, List[str]],
         backfilled: bool = False,
-    ):
+    ) -> None:
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
 
@@ -136,7 +133,7 @@ class PersistEventsStore:
             backfilled
 
         Returns:
-            Deferred: resolves when the events have been persisted
+            Resolves when the events have been persisted
         """
 
         # We want to calculate the stream orderings as late as possible, as
@@ -168,7 +165,7 @@ class PersistEventsStore:
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
@@ -206,16 +203,15 @@ class PersistEventsStore:
                     (room_id,), list(latest_event_ids)
                 )
 
-    @defer.inlineCallbacks
-    def _get_events_which_are_prevs(self, event_ids):
+    async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.
 
         Args:
-            event_ids (Iterable[str]): event ids to filter
+            event_ids: event ids to filter
 
         Returns:
-            Deferred[List[str]]: filtered event ids
+            Filtered event ids
         """
         results = []
 
@@ -240,14 +236,13 @@ class PersistEventsStore:
             results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
             )
 
         return results
 
-    @defer.inlineCallbacks
-    def _get_prevs_before_rejected(self, event_ids):
+    async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
         """Get soft-failed ancestors to remove from the extremities.
 
         Given a set of events, find all those that have been soft-failed or
@@ -259,11 +254,11 @@ class PersistEventsStore:
         are separated by soft failed events.
 
         Args:
-            event_ids (Iterable[str]): Events to find prev events for. Note
-                that these must have already been persisted.
+            event_ids: Events to find prev events for. Note that these must have
+                already been persisted.
 
         Returns:
-            Deferred[set[str]]
+            The previous events.
         """
 
         # The set of event_ids to return. This includes all soft-failed events
@@ -304,7 +299,7 @@ class PersistEventsStore:
                         existing_prevs.add(prev_event_id)
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
             )