diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index f159400a87..dbaeef91dd 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -20,21 +20,17 @@ import logging
from collections import deque, namedtuple
from typing import Iterable, List, Optional, Set, Tuple
-from six import iteritems
-from six.moves import range
-
from prometheus_client import Counter, Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.events import FrozenEvent
+from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
-from synapse.storage.data_stores import DataStores
-from synapse.storage.data_stores.main.events import DeltaState
+from synapse.storage.databases import Databases
+from synapse.storage.databases.main.events import DeltaState
from synapse.types import StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -73,7 +69,7 @@ stale_forward_extremities_counter = Histogram(
)
-class _EventPeristenceQueue(object):
+class _EventPeristenceQueue:
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
"""
@@ -176,14 +172,14 @@ class _EventPeristenceQueue(object):
pass
-class EventsPersistenceStorage(object):
+class EventsPersistenceStorage:
"""High level interface for handling persisting newly received events.
Takes care of batching up events by room, and calculating the necessary
current state and forward extremity changes.
"""
- def __init__(self, hs, stores: DataStores):
+ def __init__(self, hs, stores: Databases):
# We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same
# store for now.
@@ -196,12 +192,11 @@ class EventsPersistenceStorage(object):
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
- @defer.inlineCallbacks
- def persist_events(
+ async def persist_events(
self,
- events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ):
+ ) -> int:
"""
Write events to the database
Args:
@@ -211,14 +206,14 @@ class EventsPersistenceStorage(object):
which might update the current state etc.
Returns:
- Deferred[int]: the stream ordering of the latest persisted event
+ the stream ordering of the latest persisted event
"""
partitioned = {}
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
- for room_id, evs_ctxs in iteritems(partitioned):
+ for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
)
@@ -227,22 +222,19 @@ class EventsPersistenceStorage(object):
for room_id in partitioned:
self._maybe_start_persisting(room_id)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
- max_persisted_id = yield self.main_store.get_current_events_token()
+ return self.main_store.get_current_events_token()
- return max_persisted_id
-
- @defer.inlineCallbacks
- def persist_event(
- self, event: FrozenEvent, context: EventContext, backfilled: bool = False
- ):
+ async def persist_event(
+ self, event: EventBase, context: EventContext, backfilled: bool = False
+ ) -> Tuple[int, int]:
"""
Returns:
- Deferred[Tuple[int, int]]: the stream ordering of ``event``,
- and the stream ordering of the latest persisted event
+ The stream ordering of `event`, and the stream ordering of the
+ latest persisted event
"""
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
@@ -250,9 +242,9 @@ class EventsPersistenceStorage(object):
self._maybe_start_persisting(event.room_id)
- yield make_deferred_yieldable(deferred)
+ await make_deferred_yieldable(deferred)
- max_persisted_id = yield self.main_store.get_current_events_token()
+ max_persisted_id = self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id: str):
@@ -266,7 +258,7 @@ class EventsPersistenceStorage(object):
async def _persist_events(
self,
- events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
"""Calculates the change to current state and forward extremities, and
@@ -319,7 +311,7 @@ class EventsPersistenceStorage(object):
(event, context)
)
- for room_id, ev_ctx_rm in iteritems(events_by_room):
+ for room_id, ev_ctx_rm in events_by_room.items():
latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
room_id
)
@@ -443,7 +435,7 @@ class EventsPersistenceStorage(object):
async def _calculate_new_extremities(
self,
room_id: str,
- event_contexts: List[Tuple[FrozenEvent, EventContext]],
+ event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str],
):
"""Calculates the new forward extremities for a room given events to
@@ -501,7 +493,7 @@ class EventsPersistenceStorage(object):
async def _get_new_state_after_events(
self,
room_id: str,
- events_context: List[Tuple[FrozenEvent, EventContext]],
+ events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@@ -651,6 +643,10 @@ class EventsPersistenceStorage(object):
room_version = await self.main_store.get_room_version_id(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
+
+ # Avoid a circular import.
+ from synapse.state import StateResolutionStore
+
res = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
@@ -674,7 +670,7 @@ class EventsPersistenceStorage(object):
to_insert = {
key: ev_id
- for key, ev_id in iteritems(current_state)
+ for key, ev_id in current_state.items()
if ev_id != existing_state.get(key)
}
@@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
async def _is_server_still_joined(
self,
room_id: str,
- ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
+ ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState,
current_state: Optional[StateMap[str]],
potentially_left_users: Set[str],
@@ -786,9 +782,3 @@ class EventsPersistenceStorage(object):
for user_id in left_users:
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
-
- async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
- """Mark the invite has having been rejected even though we failed to
- create a leave event for it.
- """
- return await self.persist_events_store.locally_reject_invite(user_id, room_id)
|