summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/data_stores/main/events.py87
-rw-r--r--synapse/storage/persist_events.py123
2 files changed, 111 insertions, 99 deletions
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index bb69c20448..596daf8909 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -19,6 +19,7 @@ import itertools
 import logging
 from collections import Counter as c_counter, OrderedDict, namedtuple
 from functools import wraps
+from typing import Dict, List, Tuple
 
 from six import iteritems, text_type
 from six.moves import range
@@ -41,8 +42,9 @@ from synapse.storage._base import make_in_list_sql_clause
 from synapse.storage.data_stores.main.event_federation import EventFederationStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
-from synapse.storage.database import Database
-from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.persist_events import DeltaState
+from synapse.types import RoomStreamToken, StateMap, get_domain_from_id
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.iterutils import batch_iter
@@ -148,30 +150,26 @@ class EventsStore(
     @defer.inlineCallbacks
     def _persist_events_and_state_updates(
         self,
-        events_and_contexts,
-        current_state_for_room,
-        state_delta_for_room,
-        new_forward_extremeties,
-        backfilled=False,
-        delete_existing=False,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        current_state_for_room: Dict[str, StateMap[str]],
+        state_delta_for_room: Dict[str, DeltaState],
+        new_forward_extremeties: Dict[str, List[str]],
+        backfilled: bool = False,
+        delete_existing: bool = False,
     ):
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
 
         Args:
-            events_and_contexts (list[(EventBase, EventContext)]):
-            current_state_for_room (dict[str, dict]): Map from room_id to the
-                current state of the room based on forward extremities
-            state_delta_for_room (dict[str, tuple]): Map from room_id to tuple
-                of `(to_delete, to_insert)` where to_delete is a list
-                of type/state keys to remove from current state, and to_insert
-                is a map (type,key)->event_id giving the state delta in each
-                room.
-            new_forward_extremities (dict[str, list[str]]): Map from room_id
-                to list of event IDs that are the new forward extremities of
-                the room.
-            backfilled (bool)
-            delete_existing (bool):
+            events_and_contexts:
+            current_state_for_room: Map from room_id to the current state of
+                the room based on forward extremities
+            state_delta_for_room: Map from room_id to the delta to apply to
+                room state
+            new_forward_extremities: Map from room_id to list of event IDs
+                that are the new forward extremities of the room.
+            backfilled
+            delete_existing
 
         Returns:
             Deferred: resolves when the events have been persisted
@@ -352,12 +350,12 @@ class EventsStore(
     @log_function
     def _persist_events_txn(
         self,
-        txn,
-        events_and_contexts,
-        backfilled,
-        delete_existing=False,
-        state_delta_for_room={},
-        new_forward_extremeties={},
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+        delete_existing: bool = False,
+        state_delta_for_room: Dict[str, DeltaState] = {},
+        new_forward_extremeties: Dict[str, List[str]] = {},
     ):
         """Insert some number of room events into the necessary database tables.
 
@@ -366,21 +364,16 @@ class EventsStore(
         whether the event was rejected.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]):
-                events to persist
-            backfilled (bool): True if the events were backfilled
-            delete_existing (bool): True to purge existing table rows for the
-                events from the database. This is useful when retrying due to
+            txn
+            events_and_contexts: events to persist
+            backfilled: True if the events were backfilled
+            delete_existing True to purge existing table rows for the events
+                from the database. This is useful when retrying due to
                 IntegrityError.
-            state_delta_for_room (dict[str, (list, dict)]):
-                The current-state delta for each room. For each room, a tuple
-                (to_delete, to_insert), being a list of type/state keys to be
-                removed from the current state, and a state set to be added to
-                the current state.
-            new_forward_extremeties (dict[str, list[str]]):
-                The new forward extremities for each room. For each room, a
-                list of the event ids which are the forward extremities.
+            state_delta_for_room: The current-state delta for each room.
+            new_forward_extremetie: The new forward extremities for each room.
+                For each room, a list of the event ids which are the forward
+                extremities.
 
         """
         all_events_and_contexts = events_and_contexts
@@ -465,9 +458,15 @@ class EventsStore(
         # room_memberships, where applicable.
         self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
 
-    def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
-        for room_id, current_state_tuple in iteritems(state_delta_by_room):
-            to_delete, to_insert = current_state_tuple
+    def _update_current_state_txn(
+        self,
+        txn: LoggingTransaction,
+        state_delta_by_room: Dict[str, DeltaState],
+        stream_id: int,
+    ):
+        for room_id, delta_state in iteritems(state_delta_by_room):
+            to_delete = delta_state.to_delete
+            to_insert = delta_state.to_insert
 
             # First we add entries to the current_state_delta_stream. We
             # do this before updating the current_state_events table so
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 1ed44925fc..368c457321 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -17,19 +17,24 @@
 
 import logging
 from collections import deque, namedtuple
+from typing import Iterable, List, Optional, Tuple
 
 from six import iteritems
 from six.moves import range
 
+import attr
 from prometheus_client import Counter, Histogram
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.state import StateResolutionStore
 from synapse.storage.data_stores import DataStores
+from synapse.types import StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -67,6 +72,19 @@ stale_forward_extremities_counter = Histogram(
 )
 
 
+@attr.s(slots=True, frozen=True)
+class DeltaState:
+    """Deltas to use to update the `current_state_events` table.
+
+    Attributes:
+        to_delete: List of type/state_keys to delete from current state
+        to_insert: Map of state to upsert into current state
+    """
+
+    to_delete = attr.ib(type=List[Tuple[str, str]])
+    to_insert = attr.ib(type=StateMap[str])
+
+
 class _EventPeristenceQueue(object):
     """Queues up events so that they can be persisted in bulk with only one
     concurrent transaction per room.
@@ -138,13 +156,12 @@ class _EventPeristenceQueue(object):
 
         self._currently_persisting_rooms.add(room_id)
 
-        @defer.inlineCallbacks
-        def handle_queue_loop():
+        async def handle_queue_loop():
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
                     try:
-                        ret = yield per_item_callback(item)
+                        ret = await per_item_callback(item)
                     except Exception:
                         with PreserveLoggingContext():
                             item.deferred.errback()
@@ -191,12 +208,16 @@ class EventsPersistenceStorage(object):
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
     @defer.inlineCallbacks
-    def persist_events(self, events_and_contexts, backfilled=False):
+    def persist_events(
+        self,
+        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        backfilled: bool = False,
+    ):
         """
         Write events to the database
         Args:
             events_and_contexts: list of tuples of (event, context)
-            backfilled (bool): Whether the results are retrieved from federation
+            backfilled: Whether the results are retrieved from federation
                 via backfill or not. Used to determine if they're "new" events
                 which might update the current state etc.
 
@@ -226,16 +247,12 @@ class EventsPersistenceStorage(object):
         return max_persisted_id
 
     @defer.inlineCallbacks
-    def persist_event(self, event, context, backfilled=False):
+    def persist_event(
+        self, event: FrozenEvent, context: EventContext, backfilled: bool = False
+    ):
         """
-
-        Args:
-            event (EventBase):
-            context (EventContext):
-            backfilled (bool):
-
         Returns:
-            Deferred: resolves to (int, int): the stream ordering of ``event``,
+            Deferred[Tuple[int, int]]: the stream ordering of ``event``,
             and the stream ordering of the latest persisted event
         """
         deferred = self._event_persist_queue.add_to_queue(
@@ -249,28 +266,22 @@ class EventsPersistenceStorage(object):
         max_persisted_id = yield self.main_store.get_current_events_token()
         return (event.internal_metadata.stream_ordering, max_persisted_id)
 
-    def _maybe_start_persisting(self, room_id):
-        @defer.inlineCallbacks
-        def persisting_queue(item):
+    def _maybe_start_persisting(self, room_id: str):
+        async def persisting_queue(item):
             with Measure(self._clock, "persist_events"):
-                yield self._persist_events(
+                await self._persist_events(
                     item.events_and_contexts, backfilled=item.backfilled
                 )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
-    @defer.inlineCallbacks
-    def _persist_events(self, events_and_contexts, backfilled=False):
+    async def _persist_events(
+        self,
+        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        backfilled: bool = False,
+    ):
         """Calculates the change to current state and forward extremities, and
         persists the given events and with those updates.
-
-        Args:
-            events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
-            delete_existing (bool):
-
-        Returns:
-            Deferred: resolves when the events have been persisted
         """
         if not events_and_contexts:
             return
@@ -315,10 +326,10 @@ class EventsPersistenceStorage(object):
                         )
 
                     for room_id, ev_ctx_rm in iteritems(events_by_room):
-                        latest_event_ids = yield self.main_store.get_latest_event_ids_in_room(
+                        latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
                             room_id
                         )
-                        new_latest_event_ids = yield self._calculate_new_extremities(
+                        new_latest_event_ids = await self._calculate_new_extremities(
                             room_id, ev_ctx_rm, latest_event_ids
                         )
 
@@ -374,7 +385,7 @@ class EventsPersistenceStorage(object):
                         with Measure(
                             self._clock, "persist_events.get_new_state_after_events"
                         ):
-                            res = yield self._get_new_state_after_events(
+                            res = await self._get_new_state_after_events(
                                 room_id,
                                 ev_ctx_rm,
                                 latest_event_ids,
@@ -389,12 +400,12 @@ class EventsPersistenceStorage(object):
                             # If there is a delta we know that we've
                             # only added or replaced state, never
                             # removed keys entirely.
-                            state_delta_for_room[room_id] = ([], delta_ids)
+                            state_delta_for_room[room_id] = DeltaState([], delta_ids)
                         elif current_state is not None:
                             with Measure(
                                 self._clock, "persist_events.calculate_state_delta"
                             ):
-                                delta = yield self._calculate_state_delta(
+                                delta = await self._calculate_state_delta(
                                     room_id, current_state
                                 )
                             state_delta_for_room[room_id] = delta
@@ -404,7 +415,7 @@ class EventsPersistenceStorage(object):
                         if current_state is not None:
                             current_state_for_room[room_id] = current_state
 
-            yield self.main_store._persist_events_and_state_updates(
+            await self.main_store._persist_events_and_state_updates(
                 chunk,
                 current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
@@ -412,8 +423,12 @@ class EventsPersistenceStorage(object):
                 backfilled=backfilled,
             )
 
-    @defer.inlineCallbacks
-    def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+    async def _calculate_new_extremities(
+        self,
+        room_id: str,
+        event_contexts: List[Tuple[FrozenEvent, EventContext]],
+        latest_event_ids: List[str],
+    ):
         """Calculates the new forward extremities for a room given events to
         persist.
 
@@ -444,13 +459,13 @@ class EventsPersistenceStorage(object):
         )
 
         # Remove any events which are prev_events of any existing events.
-        existing_prevs = yield self.main_store._get_events_which_are_prevs(result)
+        existing_prevs = await self.main_store._get_events_which_are_prevs(result)
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev
         # events. If they do we need to remove them and their prev events,
         # otherwise we end up with dangling extremities.
-        existing_prevs = yield self.main_store._get_prevs_before_rejected(
+        existing_prevs = await self.main_store._get_prevs_before_rejected(
             e_id for event in new_events for e_id in event.prev_event_ids()
         )
         result.difference_update(existing_prevs)
@@ -464,10 +479,13 @@ class EventsPersistenceStorage(object):
 
         return result
 
-    @defer.inlineCallbacks
-    def _get_new_state_after_events(
-        self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
-    ):
+    async def _get_new_state_after_events(
+        self,
+        room_id: str,
+        events_context: List[Tuple[FrozenEvent, EventContext]],
+        old_latest_event_ids: Iterable[str],
+        new_latest_event_ids: Iterable[str],
+    ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
         """Calculate the current state dict after adding some new events to
         a room
 
@@ -485,7 +503,6 @@ class EventsPersistenceStorage(object):
                 the new forward extremities for the room.
 
         Returns:
-            Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
             Returns a tuple of two state maps, the first being the full new current
             state and the second being the delta to the existing current state.
             If both are None then there has been no change.
@@ -547,7 +564,7 @@ class EventsPersistenceStorage(object):
 
         if missing_event_ids:
             # Now pull out the state groups for any missing events from DB
-            event_to_groups = yield self.main_store._get_state_group_for_events(
+            event_to_groups = await self.main_store._get_state_group_for_events(
                 missing_event_ids
             )
             event_id_to_state_group.update(event_to_groups)
@@ -588,7 +605,7 @@ class EventsPersistenceStorage(object):
         # their state IDs so we can resolve to a single state set.
         missing_state = new_state_groups - set(state_groups_map)
         if missing_state:
-            group_to_state = yield self.state_store._get_state_for_groups(missing_state)
+            group_to_state = await self.state_store._get_state_for_groups(missing_state)
             state_groups_map.update(group_to_state)
 
         if len(new_state_groups) == 1:
@@ -612,10 +629,10 @@ class EventsPersistenceStorage(object):
                 break
 
         if not room_version:
-            room_version = yield self.main_store.get_room_version(room_id)
+            room_version = await self.main_store.get_room_version(room_id)
 
         logger.debug("calling resolve_state_groups from preserve_events")
-        res = yield self._state_resolution_handler.resolve_state_groups(
+        res = await self._state_resolution_handler.resolve_state_groups(
             room_id,
             room_version,
             state_groups,
@@ -625,18 +642,14 @@ class EventsPersistenceStorage(object):
 
         return res.state, None
 
-    @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, current_state):
+    async def _calculate_state_delta(
+        self, room_id: str, current_state: StateMap[str]
+    ) -> DeltaState:
         """Calculate the new state deltas for a room.
 
         Assumes that we are only persisting events for one room at a time.
-
-        Returns:
-            tuple[list, dict] (to_delete, to_insert): where to_delete are the
-            type/state_keys to remove from current_state_events and `to_insert`
-            are the updates to current_state_events.
         """
-        existing_state = yield self.main_store.get_current_state_ids(room_id)
+        existing_state = await self.main_store.get_current_state_ids(room_id)
 
         to_delete = [key for key in existing_state if key not in current_state]
 
@@ -646,4 +659,4 @@ class EventsPersistenceStorage(object):
             if ev_id != existing_state.get(key)
         }
 
-        return to_delete, to_insert
+        return DeltaState(to_delete=to_delete, to_insert=to_insert)