diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index dbaeef91dd..72939f3984 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,7 +18,7 @@
import itertools
import logging
from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Histogram
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -185,18 +185,21 @@ class EventsPersistenceStorage:
# store for now.
self.main_store = stores.main
self.state_store = stores.state
+
+ assert stores.persist_events
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
async def persist_events(
self,
- events_and_contexts: List[Tuple[EventBase, EventContext]],
+ events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ) -> int:
+ ) -> RoomStreamToken:
"""
Write events to the database
Args:
@@ -208,7 +211,7 @@ class EventsPersistenceStorage:
Returns:
the stream ordering of the latest persisted event
"""
- partitioned = {}
+ partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
@@ -226,11 +229,11 @@ class EventsPersistenceStorage:
defer.gatherResults(deferreds, consumeErrors=True)
)
- return self.main_store.get_current_events_token()
+ return self.main_store.get_room_max_token()
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
- ) -> Tuple[int, int]:
+ ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
"""
Returns:
The stream ordering of `event`, and the stream ordering of the
@@ -244,8 +247,10 @@ class EventsPersistenceStorage:
await make_deferred_yieldable(deferred)
- max_persisted_id = self.main_store.get_current_events_token()
- return (event.internal_metadata.stream_ordering, max_persisted_id)
+ event_stream_id = event.internal_metadata.stream_ordering
+
+ pos = PersistedEventPosition(self._instance_name, event_stream_id)
+ return pos, self.main_store.get_room_max_token()
def _maybe_start_persisting(self, room_id: str):
async def persisting_queue(item):
@@ -305,7 +310,9 @@ class EventsPersistenceStorage:
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
- events_by_room = {}
+ events_by_room = (
+ {}
+ ) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
@@ -436,7 +443,7 @@ class EventsPersistenceStorage:
self,
room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]],
- latest_event_ids: List[str],
+ latest_event_ids: Collection[str],
):
"""Calculates the new forward extremities for a room given events to
persist.
@@ -470,7 +477,7 @@ class EventsPersistenceStorage:
# Remove any events which are prev_events of any existing events.
existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
result
- )
+ ) # type: Collection[str]
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
|