summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--mypy.ini2
-rw-r--r--synapse/storage/databases/main/events.py21
-rw-r--r--synapse/storage/persist_events.py16
3 files changed, 27 insertions, 12 deletions
diff --git a/mypy.ini b/mypy.ini
index 7764f17856..063416c5cc 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -46,10 +46,12 @@ files =
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/databases/main/events.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
   synapse/storage/engines,
+  synapse/storage/persist_events.py,
   synapse/storage/state.py,
   synapse/storage/util,
   synapse/streams,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b3d27a2ee7..9cd1403b38 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -213,7 +213,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []
+        results = []  # type: List[str]
 
         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -631,7 +631,9 @@ class PersistEventsStore:
         )
 
     @classmethod
-    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+    def _filter_events_and_contexts_for_duplicates(
+        cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Ensure that we don't have the same event twice.
 
         Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +643,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = OrderedDict()
+        new_events_and_contexts = (
+            OrderedDict()
+        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -655,7 +659,12 @@ class PersistEventsStore:
                 new_events_and_contexts[event.event_id] = (event, context)
         return list(new_events_and_contexts.values())
 
-    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+    def _update_room_depths_txn(
+        self,
+        txn,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+    ):
         """Update min_depth for each room
 
         Args:
@@ -664,7 +673,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}
+        depth_updates = {}  # type: Dict[str, int]
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1436,7 +1445,7 @@ class PersistEventsStore:
 
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}
+        events_by_room = {}  # type: Dict[str, List[EventBase]]
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
 
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index dbaeef91dd..d89f6ed128 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,7 +18,7 @@
 import itertools
 import logging
 from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from prometheus_client import Counter, Histogram
 
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -185,6 +185,8 @@ class EventsPersistenceStorage:
         # store for now.
         self.main_store = stores.main
         self.state_store = stores.state
+
+        assert stores.persist_events
         self.persist_events_store = stores.persist_events
 
         self._clock = hs.get_clock()
@@ -208,7 +210,7 @@ class EventsPersistenceStorage:
         Returns:
             the stream ordering of the latest persisted event
         """
-        partitioned = {}
+        partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
@@ -305,7 +307,9 @@ class EventsPersistenceStorage:
                     # Work out the new "current state" for each room.
                     # We do this by working out what the new extremities are and then
                     # calculating the state from that.
-                    events_by_room = {}
+                    events_by_room = (
+                        {}
+                    )  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
                     for event, context in chunk:
                         events_by_room.setdefault(event.room_id, []).append(
                             (event, context)
@@ -436,7 +440,7 @@ class EventsPersistenceStorage:
         self,
         room_id: str,
         event_contexts: List[Tuple[EventBase, EventContext]],
-        latest_event_ids: List[str],
+        latest_event_ids: Collection[str],
     ):
         """Calculates the new forward extremities for a room given events to
         persist.
@@ -470,7 +474,7 @@ class EventsPersistenceStorage:
         # Remove any events which are prev_events of any existing events.
         existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
             result
-        )
+        )  # type: Collection[str]
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev