summary refs log tree commit diff
path: root/synapse/storage/persist_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/persist_events.py')
-rw-r--r--synapse/storage/persist_events.py123
1 files changed, 68 insertions, 55 deletions
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 1ed44925fc..368c457321 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -17,19 +17,24 @@
 
 import logging
 from collections import deque, namedtuple
+from typing import Iterable, List, Optional, Tuple
 
 from six import iteritems
 from six.moves import range
 
+import attr
 from prometheus_client import Counter, Histogram
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.state import StateResolutionStore
 from synapse.storage.data_stores import DataStores
+from synapse.types import StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -67,6 +72,19 @@ stale_forward_extremities_counter = Histogram(
 )
 
 
+@attr.s(slots=True, frozen=True)
+class DeltaState:
+    """Deltas to use to update the `current_state_events` table.
+
+    Attributes:
+        to_delete: List of type/state_keys to delete from current state
+        to_insert: Map of state to upsert into current state
+    """
+
+    to_delete = attr.ib(type=List[Tuple[str, str]])
+    to_insert = attr.ib(type=StateMap[str])
+
+
 class _EventPeristenceQueue(object):
     """Queues up events so that they can be persisted in bulk with only one
     concurrent transaction per room.
@@ -138,13 +156,12 @@ class _EventPeristenceQueue(object):
 
         self._currently_persisting_rooms.add(room_id)
 
-        @defer.inlineCallbacks
-        def handle_queue_loop():
+        async def handle_queue_loop():
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
                     try:
-                        ret = yield per_item_callback(item)
+                        ret = await per_item_callback(item)
                     except Exception:
                         with PreserveLoggingContext():
                             item.deferred.errback()
@@ -191,12 +208,16 @@ class EventsPersistenceStorage(object):
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
     @defer.inlineCallbacks
-    def persist_events(self, events_and_contexts, backfilled=False):
+    def persist_events(
+        self,
+        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        backfilled: bool = False,
+    ):
         """
         Write events to the database
         Args:
             events_and_contexts: list of tuples of (event, context)
-            backfilled (bool): Whether the results are retrieved from federation
+            backfilled: Whether the results are retrieved from federation
                 via backfill or not. Used to determine if they're "new" events
                 which might update the current state etc.
 
@@ -226,16 +247,12 @@ class EventsPersistenceStorage(object):
         return max_persisted_id
 
     @defer.inlineCallbacks
-    def persist_event(self, event, context, backfilled=False):
+    def persist_event(
+        self, event: FrozenEvent, context: EventContext, backfilled: bool = False
+    ):
         """
-
-        Args:
-            event (EventBase):
-            context (EventContext):
-            backfilled (bool):
-
         Returns:
-            Deferred: resolves to (int, int): the stream ordering of ``event``,
+            Deferred[Tuple[int, int]]: the stream ordering of ``event``,
             and the stream ordering of the latest persisted event
         """
         deferred = self._event_persist_queue.add_to_queue(
@@ -249,28 +266,22 @@ class EventsPersistenceStorage(object):
         max_persisted_id = yield self.main_store.get_current_events_token()
         return (event.internal_metadata.stream_ordering, max_persisted_id)
 
-    def _maybe_start_persisting(self, room_id):
-        @defer.inlineCallbacks
-        def persisting_queue(item):
+    def _maybe_start_persisting(self, room_id: str):
+        async def persisting_queue(item):
             with Measure(self._clock, "persist_events"):
-                yield self._persist_events(
+                await self._persist_events(
                     item.events_and_contexts, backfilled=item.backfilled
                 )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
-    @defer.inlineCallbacks
-    def _persist_events(self, events_and_contexts, backfilled=False):
+    async def _persist_events(
+        self,
+        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        backfilled: bool = False,
+    ):
         """Calculates the change to current state and forward extremities, and
         persists the given events and with those updates.
-
-        Args:
-            events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
-            delete_existing (bool):
-
-        Returns:
-            Deferred: resolves when the events have been persisted
         """
         if not events_and_contexts:
             return
@@ -315,10 +326,10 @@ class EventsPersistenceStorage(object):
                         )
 
                     for room_id, ev_ctx_rm in iteritems(events_by_room):
-                        latest_event_ids = yield self.main_store.get_latest_event_ids_in_room(
+                        latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
                             room_id
                         )
-                        new_latest_event_ids = yield self._calculate_new_extremities(
+                        new_latest_event_ids = await self._calculate_new_extremities(
                             room_id, ev_ctx_rm, latest_event_ids
                         )
 
@@ -374,7 +385,7 @@ class EventsPersistenceStorage(object):
                         with Measure(
                             self._clock, "persist_events.get_new_state_after_events"
                         ):
-                            res = yield self._get_new_state_after_events(
+                            res = await self._get_new_state_after_events(
                                 room_id,
                                 ev_ctx_rm,
                                 latest_event_ids,
@@ -389,12 +400,12 @@ class EventsPersistenceStorage(object):
                             # If there is a delta we know that we've
                             # only added or replaced state, never
                             # removed keys entirely.
-                            state_delta_for_room[room_id] = ([], delta_ids)
+                            state_delta_for_room[room_id] = DeltaState([], delta_ids)
                         elif current_state is not None:
                             with Measure(
                                 self._clock, "persist_events.calculate_state_delta"
                             ):
-                                delta = yield self._calculate_state_delta(
+                                delta = await self._calculate_state_delta(
                                     room_id, current_state
                                 )
                             state_delta_for_room[room_id] = delta
@@ -404,7 +415,7 @@ class EventsPersistenceStorage(object):
                         if current_state is not None:
                             current_state_for_room[room_id] = current_state
 
-            yield self.main_store._persist_events_and_state_updates(
+            await self.main_store._persist_events_and_state_updates(
                 chunk,
                 current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
@@ -412,8 +423,12 @@ class EventsPersistenceStorage(object):
                 backfilled=backfilled,
             )
 
-    @defer.inlineCallbacks
-    def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+    async def _calculate_new_extremities(
+        self,
+        room_id: str,
+        event_contexts: List[Tuple[FrozenEvent, EventContext]],
+        latest_event_ids: List[str],
+    ):
         """Calculates the new forward extremities for a room given events to
         persist.
 
@@ -444,13 +459,13 @@ class EventsPersistenceStorage(object):
         )
 
         # Remove any events which are prev_events of any existing events.
-        existing_prevs = yield self.main_store._get_events_which_are_prevs(result)
+        existing_prevs = await self.main_store._get_events_which_are_prevs(result)
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev
         # events. If they do we need to remove them and their prev events,
         # otherwise we end up with dangling extremities.
-        existing_prevs = yield self.main_store._get_prevs_before_rejected(
+        existing_prevs = await self.main_store._get_prevs_before_rejected(
             e_id for event in new_events for e_id in event.prev_event_ids()
         )
         result.difference_update(existing_prevs)
@@ -464,10 +479,13 @@ class EventsPersistenceStorage(object):
 
         return result
 
-    @defer.inlineCallbacks
-    def _get_new_state_after_events(
-        self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
-    ):
+    async def _get_new_state_after_events(
+        self,
+        room_id: str,
+        events_context: List[Tuple[FrozenEvent, EventContext]],
+        old_latest_event_ids: Iterable[str],
+        new_latest_event_ids: Iterable[str],
+    ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
         """Calculate the current state dict after adding some new events to
         a room
 
@@ -485,7 +503,6 @@ class EventsPersistenceStorage(object):
                 the new forward extremities for the room.
 
         Returns:
-            Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
             Returns a tuple of two state maps, the first being the full new current
             state and the second being the delta to the existing current state.
             If both are None then there has been no change.
@@ -547,7 +564,7 @@ class EventsPersistenceStorage(object):
 
         if missing_event_ids:
             # Now pull out the state groups for any missing events from DB
-            event_to_groups = yield self.main_store._get_state_group_for_events(
+            event_to_groups = await self.main_store._get_state_group_for_events(
                 missing_event_ids
             )
             event_id_to_state_group.update(event_to_groups)
@@ -588,7 +605,7 @@ class EventsPersistenceStorage(object):
         # their state IDs so we can resolve to a single state set.
         missing_state = new_state_groups - set(state_groups_map)
         if missing_state:
-            group_to_state = yield self.state_store._get_state_for_groups(missing_state)
+            group_to_state = await self.state_store._get_state_for_groups(missing_state)
             state_groups_map.update(group_to_state)
 
         if len(new_state_groups) == 1:
@@ -612,10 +629,10 @@ class EventsPersistenceStorage(object):
                 break
 
         if not room_version:
-            room_version = yield self.main_store.get_room_version(room_id)
+            room_version = await self.main_store.get_room_version(room_id)
 
         logger.debug("calling resolve_state_groups from preserve_events")
-        res = yield self._state_resolution_handler.resolve_state_groups(
+        res = await self._state_resolution_handler.resolve_state_groups(
             room_id,
             room_version,
             state_groups,
@@ -625,18 +642,14 @@ class EventsPersistenceStorage(object):
 
         return res.state, None
 
-    @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, current_state):
+    async def _calculate_state_delta(
+        self, room_id: str, current_state: StateMap[str]
+    ) -> DeltaState:
         """Calculate the new state deltas for a room.
 
         Assumes that we are only persisting events for one room at a time.
-
-        Returns:
-            tuple[list, dict] (to_delete, to_insert): where to_delete are the
-            type/state_keys to remove from current_state_events and `to_insert`
-            are the updates to current_state_events.
         """
-        existing_state = yield self.main_store.get_current_state_ids(room_id)
+        existing_state = await self.main_store.get_current_state_ids(room_id)
 
         to_delete = [key for key in existing_state if key not in current_state]
 
@@ -646,4 +659,4 @@ class EventsPersistenceStorage(object):
             if ev_id != existing_state.get(key)
         }
 
-        return to_delete, to_insert
+        return DeltaState(to_delete=to_delete, to_insert=to_insert)