diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 8e5d78f6f7..bbff3c8d5b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -47,6 +47,9 @@ class Storage:
# interfaces.
self.main = stores.main
- self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
+
+ self.persistence = None
+ if stores.persist_events:
+ self.persistence = EventsPersistenceStorage(hs, stores)
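With this change `Storage.persistence` becomes `None` on workers that are not
configured to persist events, so callers can no longer assume the attribute is
set. A minimal caller-side sketch of the guard this implies (the function name
and error message are illustrative, not part of this diff):

```python
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.storage import Storage


async def persist_or_fail(storage: Storage, event: EventBase, context: EventContext):
    # `Storage.persistence` is now Optional: None on workers without an
    # events-persisting store.
    if storage.persistence is None:
        raise NotImplementedError("This worker cannot persist events")
    return await storage.persistence.persist_event(event, context)
```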
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b3d27a2ee7..9cd1403b38 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -213,7 +213,7 @@ class PersistEventsStore:
Returns:
Filtered event ids
"""
- results = []
+ results = [] # type: List[str]
def _get_events_which_are_prevs_txn(txn, batch):
sql = """
@@ -631,7 +631,9 @@ class PersistEventsStore:
)
@classmethod
- def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+ def _filter_events_and_contexts_for_duplicates(
+ cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Ensure that we don't have the same event twice.
Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +643,9 @@ class PersistEventsStore:
Returns:
list[(EventBase, EventContext)]: filtered list
"""
- new_events_and_contexts = OrderedDict()
+ new_events_and_contexts = (
+ OrderedDict()
+ ) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
@@ -655,7 +659,12 @@ class PersistEventsStore:
new_events_and_contexts[event.event_id] = (event, context)
return list(new_events_and_contexts.values())
- def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+ def _update_room_depths_txn(
+ self,
+ txn,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ backfilled: bool,
+ ):
"""Update min_depth for each room
Args:
@@ -664,7 +673,7 @@ class PersistEventsStore:
we are persisting
backfilled (bool): True if the events were backfilled
"""
- depth_updates = {}
+ depth_updates = {} # type: Dict[str, int]
for event, context in events_and_contexts:
# Remove any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1436,7 +1445,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- events_by_room = {}
+ events_by_room = {} # type: Dict[str, List[EventBase]]
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
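The annotations in this file use PEP 484 type comments (`# type: ...`) rather
than variable annotations, matching the existing style of the codebase. A
self-contained sketch of the dict-of-lists pattern being annotated (names are
illustrative):

```python
from typing import Dict, List

# The type comment tells mypy the value type up front, so the
# setdefault(...).append(...) idiom type-checks cleanly.
events_by_room = {}  # type: Dict[str, List[str]]
events_by_room.setdefault("!room:example.org", []).append("$event_id")
```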
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 08a13a8b47..2e95518752 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -310,11 +310,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
- from_key: str,
- to_key: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
- ) -> Dict[str, Tuple[List[EventBase], str]]:
+ ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
- room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
+ room_ids = self._events_stream_cache.get_entities_changed(
+ room_ids, from_key.stream
+ )
if not room_ids:
return {}
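A hedged sketch of what the `_events_stream_cache` call above does, assuming
synapse's `StreamChangeCache` API: rooms that cannot have changed since the
caller's stream position are filtered out before any database work happens.

```python
from synapse.util.caches.stream_change_cache import StreamChangeCache

# The cache tracks the latest stream position at which each entity
# (here: room) changed.
cache = StreamChangeCache("ExampleEventsStreamCache", current_stream_pos=0)
cache.entity_has_changed("!room1:example.org", 5)

# Only entities that may have changed after position 3 are returned;
# !room2 was never marked as changed, so it is dropped.
changed = cache.get_entities_changed(
    ["!room1:example.org", "!room2:example.org"], stream_pos=3
)
```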
@@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(
- self, room_ids: Collection[str], from_key: str
+ self, room_ids: Collection[str], from_key: RoomStreamToken
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
-
- Args:
- room_ids
- from_key: The room_key portion of a StreamToken
"""
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ from_id = from_key.stream
return {
room_id
for room_id in room_ids
@@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_room(
self,
room_id: str,
- from_key: str,
- to_key: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
- ) -> Tuple[List[EventBase], str]:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if from_key == to_key:
return [], from_key
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ from_id = from_key.stream
+ to_id = to_key.stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = "s%d" % min(r.stream_ordering for r in rows)
+ key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
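`RoomStreamToken(None, stream_ordering)` builds a live, stream-only token: the
first field is the topological ordering, which is `None` for tokens pointing at
the live event stream rather than a historical position. Using only what this
diff shows:

```python
from synapse.types import RoomStreamToken

# A stream-only ("live") token, as constructed above from the smallest
# stream_ordering among the returned rows:
key = RoomStreamToken(None, 42)
assert key.topological is None
assert key.stream == 42
```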
@@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
- self, user_id: str, from_key: str, to_key: str
+ self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ from_id = from_key.stream
+ to_id = to_key.stream
if from_key == to_key:
return []
@@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
async def get_recent_events_for_room(
- self, room_id: str, limit: int, end_token: str
- ) -> Tuple[List[EventBase], str]:
+ self, room_id: str, limit: int, end_token: RoomStreamToken
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
async def get_recent_event_ids_for_room(
- self, room_id: str, limit: int, end_token: str
- ) -> Tuple[List[_EventDictReturn], str]:
+ self, room_id: str, limit: int, end_token: RoomStreamToken
+ ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -535,13 +531,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
- parsed_end_token = RoomStreamToken.parse(end_token)
-
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
- from_token=parsed_end_token,
+ from_token=end_token,
limit=limit,
)
@@ -619,17 +613,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
allow_none=allow_none,
)
- async def get_stream_token_for_event(self, event_id: str) -> str:
+ async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A "s%d" stream token.
+ A stream token.
"""
stream_id = await self.get_stream_id_for_event(event_id)
- return "s%d" % (stream_id,)
+ return RoomStreamToken(None, stream_id)
async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
@@ -954,7 +948,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[_EventDictReturn], str]:
+ ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -1054,17 +1048,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
- return rows, str(next_token)
+ return rows, next_token
async def paginate_room_events(
self,
room_id: str,
- from_key: str,
- to_key: Optional[str] = None,
+ from_key: RoomStreamToken,
+ to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[EventBase], str]:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -1083,17 +1077,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
- parsed_from_key = RoomStreamToken.parse(from_key)
- parsed_to_key = None
- if to_key:
- parsed_to_key = RoomStreamToken.parse(to_key)
-
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
- parsed_from_key,
- parsed_to_key,
+ from_key,
+ to_key,
direction,
limit,
event_filter,
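The net effect of this file's changes is that `RoomStreamToken.parse` and
`parse_stream_token` move out of the storage layer: callers parse a serialized
token once at the boundary and pass the structured token around. A sketch of
the intended call pattern (the `store` parameter is assumed to be a
`StreamWorkerStore`; method names are taken from this diff):

```python
from typing import List, Tuple

from synapse.events import EventBase
from synapse.types import RoomStreamToken


async def first_page(store, room_id: str) -> Tuple[List[EventBase], RoomStreamToken]:
    # Parse the serialized token once, at the API boundary...
    from_key = RoomStreamToken.parse("s42")
    # ...then hand the structured token straight to the store.
    return await store.paginate_room_events(room_id, from_key, limit=10)
```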
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index dbaeef91dd..d89f6ed128 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,7 +18,7 @@
import itertools
import logging
from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Histogram
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -185,6 +185,8 @@ class EventsPersistenceStorage:
# store for now.
self.main_store = stores.main
self.state_store = stores.state
+
+ assert stores.persist_events
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
@@ -208,7 +210,7 @@ class EventsPersistenceStorage:
Returns:
the stream ordering of the latest persisted event
"""
- partitioned = {}
+ partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
@@ -305,7 +307,9 @@ class EventsPersistenceStorage:
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
- events_by_room = {}
+ events_by_room = (
+ {}
+ ) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
@@ -436,7 +440,7 @@ class EventsPersistenceStorage:
self,
room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]],
- latest_event_ids: List[str],
+ latest_event_ids: Collection[str],
):
"""Calculates the new forward extremities for a room given events to
persist.
@@ -470,7 +474,7 @@ class EventsPersistenceStorage:
# Remove any events which are prev_events of any existing events.
existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
result
- )
+ ) # type: Collection[str]
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
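The `assert stores.persist_events` added above does double duty: it fails fast
if `EventsPersistenceStorage` is constructed on a worker without an
events-persisting store, and it lets mypy narrow the attribute from
`Optional[...]` to the concrete type for the assignment that follows. A generic
sketch of the narrowing pattern (types are illustrative):

```python
from typing import Optional


class Stores:
    def __init__(self, persist_events: Optional[str] = None):
        self.persist_events = persist_events


def use(stores: Stores) -> str:
    # Without the assert, mypy reports Optional[str] here; with it, the
    # attribute is narrowed to str for the rest of the function.
    assert stores.persist_events
    return stores.persist_events
```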
|