summary refs log tree commit diff
path: root/synapse/storage/background_updates.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/background_updates.py')
-rw-r--r--synapse/storage/background_updates.py111
1 files changed, 64 insertions, 47 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 810721ebe9..29b8ca676a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,29 +12,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 
 from . import engines
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.storage.database import DatabasePool, LoggingTransaction
+
 logger = logging.getLogger(__name__)
 
 
 class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         self.name = name
         self.total_item_count = 0
-        self.total_duration_ms = 0
-        self.avg_item_count = 0
-        self.avg_duration_ms = 0
+        self.total_duration_ms = 0.0
+        self.avg_item_count = 0.0
+        self.avg_duration_ms = 0.0
 
-    def update(self, item_count, duration_ms):
+    def update(self, item_count: int, duration_ms: float) -> None:
         """Update the stats after doing an update"""
         self.total_item_count += item_count
         self.total_duration_ms += duration_ms
@@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
         self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
         self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
 
-    def average_items_per_ms(self):
+    def average_items_per_ms(self) -> Optional[float]:
         """An estimate of how long it takes to do a single update.
         Returns:
             A duration in ms as a float
@@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
             # changes in how long the update process takes.
             return float(self.avg_item_count) / float(self.avg_duration_ms)
 
-    def total_items_per_ms(self):
+    def total_items_per_ms(self) -> Optional[float]:
         """An estimate of how long it takes to do a single update.
         Returns:
             A duration in ms as a float
@@ -83,21 +88,25 @@ class BackgroundUpdater:
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
 
-    def __init__(self, hs, database):
+    def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database
 
         # if a background update is currently running, its name.
         self._current_background_update = None  # type: Optional[str]
 
-        self._background_update_performance = {}
-        self._background_update_handlers = {}
+        self._background_update_performance = (
+            {}
+        )  # type: Dict[str, BackgroundUpdatePerformance]
+        self._background_update_handlers = (
+            {}
+        )  # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
         self._all_done = False
 
-    def start_doing_background_updates(self):
+    def start_doing_background_updates(self) -> None:
         run_as_background_process("background_updates", self.run_background_updates)
 
-    async def run_background_updates(self, sleep=True):
+    async def run_background_updates(self, sleep: bool = True) -> None:
         logger.info("Starting background schema updates")
         while True:
             if sleep:
@@ -148,7 +157,7 @@ class BackgroundUpdater:
 
         return False
 
-    async def has_completed_background_update(self, update_name) -> bool:
+    async def has_completed_background_update(self, update_name: str) -> bool:
         """Check if the given background update has finished running.
         """
         if self._all_done:
@@ -173,8 +182,7 @@ class BackgroundUpdater:
         Returns once some amount of work is done.
 
         Args:
-            desired_duration_ms(float): How long we want to spend
-                updating.
+            desired_duration_ms: How long we want to spend updating.
         Returns:
             True if we have finished running all the background updates, otherwise False
         """
@@ -220,6 +228,7 @@ class BackgroundUpdater:
         return False
 
     async def _do_background_update(self, desired_duration_ms: float) -> int:
+        assert self._current_background_update is not None
         update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)
 
@@ -273,7 +282,11 @@ class BackgroundUpdater:
 
         return len(self._background_update_performance)
 
-    def register_background_update_handler(self, update_name, update_handler):
+    def register_background_update_handler(
+        self,
+        update_name: str,
+        update_handler: Callable[[JsonDict, int], Awaitable[int]],
+    ):
         """Register a handler for doing a background update.
 
         The handler should take two arguments:
@@ -287,12 +300,12 @@ class BackgroundUpdater:
         The handler is responsible for updating the progress of the update.
 
         Args:
-            update_name(str): The name of the update that this code handles.
-            update_handler(function): The function that does the update.
+            update_name: The name of the update that this code handles.
+            update_handler: The function that does the update.
         """
         self._background_update_handlers[update_name] = update_handler
 
-    def register_noop_background_update(self, update_name):
+    def register_noop_background_update(self, update_name: str) -> None:
         """Register a noop handler for a background update.
 
         This is useful when we previously did a background update, but no
@@ -302,10 +315,10 @@ class BackgroundUpdater:
         also be called to clear the update.
 
         Args:
-            update_name (str): Name of update
+            update_name: Name of update
         """
 
-        async def noop_update(progress, batch_size):
+        async def noop_update(progress: JsonDict, batch_size: int) -> int:
             await self._end_background_update(update_name)
             return 1
 
@@ -313,14 +326,14 @@ class BackgroundUpdater:
 
     def register_background_index_update(
         self,
-        update_name,
-        index_name,
-        table,
-        columns,
-        where_clause=None,
-        unique=False,
-        psql_only=False,
-    ):
+        update_name: str,
+        index_name: str,
+        table: str,
+        columns: Iterable[str],
+        where_clause: Optional[str] = None,
+        unique: bool = False,
+        psql_only: bool = False,
+    ) -> None:
         """Helper for store classes to do a background index addition
 
         To use:
@@ -332,19 +345,19 @@ class BackgroundUpdater:
         2. In the Store constructor, call this method
 
         Args:
-            update_name (str): update_name to register for
-            index_name (str): name of index to add
-            table (str): table to add index to
-            columns (list[str]): columns/expressions to include in index
-            unique (bool): true to make a UNIQUE index
+            update_name: update_name to register for
+            index_name: name of index to add
+            table: table to add index to
+            columns: columns/expressions to include in index
+            unique: true to make a UNIQUE index
             psql_only: true to only create this index on psql databases (useful
                 for virtual sqlite tables)
         """
 
-        def create_index_psql(conn):
+        def create_index_psql(conn: Connection) -> None:
             conn.rollback()
             # postgres insists on autocommit for the index
-            conn.set_session(autocommit=True)
+            conn.set_session(autocommit=True)  # type: ignore
 
             try:
                 c = conn.cursor()
@@ -371,9 +384,9 @@ class BackgroundUpdater:
                 logger.debug("[SQL] %s", sql)
                 c.execute(sql)
             finally:
-                conn.set_session(autocommit=False)
+                conn.set_session(autocommit=False)  # type: ignore
 
-        def create_index_sqlite(conn):
+        def create_index_sqlite(conn: Connection) -> None:
             # Sqlite doesn't support concurrent creation of indexes.
             #
             # We don't use partial indices on SQLite as it wasn't introduced
@@ -399,7 +412,7 @@ class BackgroundUpdater:
             c.execute(sql)
 
         if isinstance(self.db_pool.engine, engines.PostgresEngine):
-            runner = create_index_psql
+            runner = create_index_psql  # type: Optional[Callable[[Connection], None]]
         elif psql_only:
             runner = None
         else:
@@ -433,7 +446,9 @@ class BackgroundUpdater:
             "background_updates", keyvalues={"update_name": update_name}
         )
 
-    async def _background_update_progress(self, update_name: str, progress: dict):
+    async def _background_update_progress(
+        self, update_name: str, progress: dict
+    ) -> None:
         """Update the progress of a background update
 
         Args:
@@ -441,20 +456,22 @@ class BackgroundUpdater:
             progress: The progress of the update.
         """
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "background_update_progress",
             self._background_update_progress_txn,
             update_name,
             progress,
         )
 
-    def _background_update_progress_txn(self, txn, update_name, progress):
+    def _background_update_progress_txn(
+        self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
+    ) -> None:
         """Update the progress of a background update
 
         Args:
-            txn(cursor): The transaction.
-            update_name(str): The name of the background update task
-            progress(dict): The progress of the update.
+            txn: The transaction.
+            update_name: The name of the background update task
+            progress: The progress of the update.
         """
 
         progress_json = json_encoder.encode(progress)