summary refs log tree commit diff
path: root/synapse/storage/databases/main/metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/metrics.py')
-rw-r--r--synapse/storage/databases/main/metrics.py56
1 files changed, 28 insertions, 28 deletions
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d03555a585..14294a0bb8 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,16 +14,19 @@
 import calendar
 import logging
 import time
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
 
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
-from synapse.storage.types import Cursor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -73,7 +76,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
     @wrap_as_background_process("read_forward_extremities")
     async def _read_forward_extremities(self) -> None:
-        def fetch(txn):
+        def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
             txn.execute(
                 """
                 SELECT t1.c, t2.c
@@ -86,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                 ) t2 ON t1.room_id = t2.room_id
                 """
             )
-            return txn.fetchall()
+            return cast(List[Tuple[int, int]], txn.fetchall())
 
         res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
 
@@ -104,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         call to this function, it will return None.
         """
 
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
 
     async def count_daily_sent_e2ee_messages(self) -> int:
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
@@ -130,7 +133,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             """
 
             txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
@@ -138,14 +141,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         )
 
     async def count_daily_active_e2ee_rooms(self) -> int:
-        def _count(txn):
+        def _count(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
@@ -160,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         call to this function, it will return None.
         """
 
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_messages", _count_messages)
 
     async def count_daily_sent_messages(self) -> int:
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
@@ -186,7 +189,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             """
 
             txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
@@ -194,14 +197,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         )
 
     async def count_daily_active_rooms(self) -> int:
-        def _count(txn):
+        def _count(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -227,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
-    def _count_users(self, txn: Cursor, time_from: int) -> int:
+    def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
         """
         Returns number of users seen in the past time_from period
         """
@@ -242,7 +245,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         # Mypy knows that fetchone() might return None if there are no rows.
         # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
         # returns exactly one row.
-        (count,) = txn.fetchone()  # type: ignore[misc]
+        (count,) = cast(Tuple[int], txn.fetchone())
         return count
 
     async def count_r30_users(self) -> Dict[str, int]:
@@ -256,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
              A mapping of counts globally as well as broken out by platform.
         """
 
-        def _count_r30_users(txn):
+        def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
             thirty_days_in_secs = 86400 * 30
             now = int(self._clock.time())
             thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -321,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
             txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
 
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             results["all"] = count
 
             return results
@@ -348,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
               - "web" (any web application -- it's not possible to distinguish Element Web here)
         """
 
-        def _count_r30v2_users(txn):
+        def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
             thirty_days_in_secs = 86400 * 30
             now = int(self._clock.time())
             sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -445,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                     thirty_days_in_secs * 1000,
                 ),
             )
-            row = txn.fetchone()
-            if row is None:
-                results["all"] = 0
-            else:
-                results["all"] = row[0]
+            (count,) = cast(Tuple[int], txn.fetchone())
+            results["all"] = count
 
             return results
 
@@ -471,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         Generates daily visit data for use in cohort/ retention analysis
         """
 
-        def _generate_user_daily_visits(txn):
+        def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
             logger.info("Calling _generate_user_daily_visits")
             today_start = self._get_start_of_day()
             a_day_in_milliseconds = 24 * 60 * 60 * 1000