summary refs log tree commit diff
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2021-12-29 14:01:13 +0100
committerGitHub <noreply@github.com>2021-12-29 08:01:13 -0500
commit15bb1c8511c13197a75df93f6a8021ec5f9586e6 (patch)
treeb46aee0b1ff8fdeec2a06d434239deb5946c2344
parentUpdate to the current version of Black and run it on Synapse codebase (#11596) (diff)
downloadsynapse-15bb1c8511c13197a75df93f6a8021ec5f9586e6.tar.xz
Add type hints to `synapse/storage/databases/main/stats.py` (#11653)
-rw-r--r--changelog.d/11653.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/storage/databases/main/stats.py94
3 files changed, 57 insertions, 42 deletions
diff --git a/changelog.d/11653.misc b/changelog.d/11653.misc
new file mode 100644
index 0000000000..8e405b9226
--- /dev/null
+++ b/changelog.d/11653.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index 57e1a5df43..724c7e2ae4 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -39,7 +39,6 @@ exclude = (?x)
    |synapse/storage/databases/main/roommember.py
    |synapse/storage/databases/main/search.py
    |synapse/storage/databases/main/state.py
-   |synapse/storage/databases/main/stats.py
    |synapse/storage/databases/main/user_directory.py
    |synapse/storage/schema/
 
@@ -214,6 +213,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.profile]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.stats]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.state_deltas]
 disallow_untyped_defs = True
 
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index a0472e37f5..427ae1f649 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
 import logging
 from enum import Enum
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
 
 from typing_extensions import Counter
 
@@ -24,7 +24,11 @@ from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.api.errors import StoreError
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
@@ -122,7 +126,9 @@ class StatsStore(StateDeltasStore):
         self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
         self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
 
-    async def _populate_stats_process_users(self, progress, batch_size):
+    async def _populate_stats_process_users(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """
         This is a background update which regenerates statistics for users.
         """
@@ -134,7 +140,7 @@ class StatsStore(StateDeltasStore):
 
         last_user_id = progress.get("last_user_id", "")
 
-        def _get_next_batch(txn):
+        def _get_next_batch(txn: LoggingTransaction) -> List[str]:
             sql = """
                     SELECT DISTINCT name FROM users
                     WHERE name > ?
@@ -168,7 +174,9 @@ class StatsStore(StateDeltasStore):
 
         return len(users_to_work_on)
 
-    async def _populate_stats_process_rooms(self, progress, batch_size):
+    async def _populate_stats_process_rooms(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """This is a background update which regenerates statistics for rooms."""
         if not self.stats_enabled:
             await self.db_pool.updates._end_background_update(
@@ -178,7 +186,7 @@ class StatsStore(StateDeltasStore):
 
         last_room_id = progress.get("last_room_id", "")
 
-        def _get_next_batch(txn):
+        def _get_next_batch(txn: LoggingTransaction) -> List[str]:
             sql = """
                     SELECT DISTINCT room_id FROM current_state_events
                     WHERE room_id > ?
@@ -307,7 +315,7 @@ class StatsStore(StateDeltasStore):
             stream_id: Current position.
         """
 
-        def _bulk_update_stats_delta_txn(txn):
+        def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None:
             for stats_type, stats_updates in updates.items():
                 for stats_id, fields in stats_updates.items():
                     logger.debug(
@@ -339,7 +347,7 @@ class StatsStore(StateDeltasStore):
         stats_type: str,
         stats_id: str,
         fields: Dict[str, int],
-        complete_with_stream_id: Optional[int],
+        complete_with_stream_id: int,
         absolute_field_overrides: Optional[Dict[str, int]] = None,
     ) -> None:
         """
@@ -372,14 +380,14 @@ class StatsStore(StateDeltasStore):
 
     def _update_stats_delta_txn(
         self,
-        txn,
-        ts,
-        stats_type,
-        stats_id,
-        fields,
-        complete_with_stream_id,
-        absolute_field_overrides=None,
-    ):
+        txn: LoggingTransaction,
+        ts: int,
+        stats_type: str,
+        stats_id: str,
+        fields: Dict[str, int],
+        complete_with_stream_id: int,
+        absolute_field_overrides: Optional[Dict[str, int]] = None,
+    ) -> None:
         if absolute_field_overrides is None:
             absolute_field_overrides = {}
 
@@ -422,20 +430,23 @@ class StatsStore(StateDeltasStore):
         )
 
     def _upsert_with_additive_relatives_txn(
-        self, txn, table, keyvalues, absolutes, additive_relatives
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        absolutes: Dict[str, Any],
+        additive_relatives: Dict[str, int],
+    ) -> None:
         """Used to update values in the stats tables.
 
         This is basically a slightly convoluted upsert that *adds* to any
         existing rows.
 
         Args:
-            txn
-            table (str): Table name
-            keyvalues (dict[str, any]): Row-identifying key values
-            absolutes (dict[str, any]): Absolute (set) fields
-            additive_relatives (dict[str, int]): Fields that will be added onto
-                if existing row present.
+            table: Table name
+            keyvalues: Row-identifying key values
+            absolutes: Absolute (set) fields
+            additive_relatives: Fields that will be added onto if existing row present.
         """
         if self.database_engine.can_native_upsert:
             absolute_updates = [
@@ -491,20 +502,17 @@ class StatsStore(StateDeltasStore):
                 current_row.update(absolutes)
                 self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
 
-    async def _calculate_and_set_initial_state_for_room(
-        self, room_id: str
-    ) -> Tuple[dict, dict, int]:
+    async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
         """Calculate and insert an entry into room_stats_current.
 
         Args:
             room_id: The room ID under calculation.
-
-        Returns:
-            A tuple of room state, membership counts and stream position.
         """
 
-        def _fetch_current_state_stats(txn):
-            pos = self.get_room_max_stream_ordering()
+        def _fetch_current_state_stats(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
+            pos = self.get_room_max_stream_ordering()  # type: ignore[attr-defined]
 
             rows = self.db_pool.simple_select_many_txn(
                 txn,
@@ -524,7 +532,7 @@ class StatsStore(StateDeltasStore):
                 retcols=["event_id"],
             )
 
-            event_ids = [row["event_id"] for row in rows]
+            event_ids = cast(List[str], [row["event_id"] for row in rows])
 
             txn.execute(
                 """
@@ -544,9 +552,9 @@ class StatsStore(StateDeltasStore):
                 (room_id,),
             )
 
-            (current_state_events_count,) = txn.fetchone()
+            current_state_events_count = cast(Tuple[int], txn.fetchone())[0]
 
-            users_in_room = self.get_users_in_room_txn(txn, room_id)
+            users_in_room = self.get_users_in_room_txn(txn, room_id)  # type: ignore[attr-defined]
 
             return (
                 event_ids,
@@ -566,7 +574,7 @@ class StatsStore(StateDeltasStore):
             "get_initial_state_for_room", _fetch_current_state_stats
         )
 
-        state_event_map = await self.get_events(event_ids, get_prev_content=False)
+        state_event_map = await self.get_events(event_ids, get_prev_content=False)  # type: ignore[attr-defined]
 
         room_state = {
             "join_rules": None,
@@ -622,8 +630,10 @@ class StatsStore(StateDeltasStore):
             },
         )
 
-    async def _calculate_and_set_initial_state_for_user(self, user_id):
-        def _calculate_and_set_initial_state_for_user_txn(txn):
+    async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None:
+        def _calculate_and_set_initial_state_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, int]:
             pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
 
             txn.execute(
@@ -634,7 +644,7 @@ class StatsStore(StateDeltasStore):
                 """,
                 (user_id,),
             )
-            (count,) = txn.fetchone()
+            count = cast(Tuple[int], txn.fetchone())[0]
             return count, pos
 
         joined_rooms, pos = await self.db_pool.runInteraction(
@@ -678,7 +688,9 @@ class StatsStore(StateDeltasStore):
             users that exist given this query
         """
 
-        def get_users_media_usage_paginate_txn(txn):
+        def get_users_media_usage_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
             filters = []
             args = [self.hs.config.server.server_name]
 
@@ -733,7 +745,7 @@ class StatsStore(StateDeltasStore):
                 sql_base=sql_base,
             )
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = """
                 SELECT