diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d03555a585..14294a0bb8 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,16 +14,19 @@
import calendar
import logging
import time
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
-from synapse.storage.types import Cursor
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -73,7 +76,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self) -> None:
- def fetch(txn):
+ def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
txn.execute(
"""
SELECT t1.c, t2.c
@@ -86,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) t2 ON t1.room_id = t2.room_id
"""
)
- return txn.fetchall()
+ return cast(List[Tuple[int, int]], txn.fetchall())
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
@@ -104,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
- def _count_messages(txn):
+ def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
async def count_daily_sent_e2ee_messages(self) -> int:
- def _count_messages(txn):
+ def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -130,7 +133,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
@@ -138,14 +141,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
async def count_daily_active_e2ee_rooms(self) -> int:
- def _count(txn):
+ def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
@@ -160,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
- def _count_messages(txn):
+ def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_messages", _count_messages)
async def count_daily_sent_messages(self) -> int:
- def _count_messages(txn):
+ def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -186,7 +189,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
@@ -194,14 +197,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
async def count_daily_active_rooms(self) -> int:
- def _count(txn):
+ def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -227,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)
- def _count_users(self, txn: Cursor, time_from: int) -> int:
+ def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@@ -242,7 +245,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
# Mypy knows that fetchone() might return None if there are no rows.
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
# returns exactly one row.
- (count,) = txn.fetchone() # type: ignore[misc]
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
async def count_r30_users(self) -> Dict[str, int]:
@@ -256,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
A mapping of counts globally as well as broken out by platform.
"""
- def _count_r30_users(txn):
+ def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -321,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count
return results
@@ -348,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""
- def _count_r30v2_users(txn):
+ def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -445,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
thirty_days_in_secs * 1000,
),
)
- row = txn.fetchone()
- if row is None:
- results["all"] = 0
- else:
- results["all"] = row[0]
+ (count,) = cast(Tuple[int], txn.fetchone())
+ results["all"] = count
return results
@@ -471,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
Generates daily visit data for use in cohort/ retention analysis
"""
- def _generate_user_daily_visits(txn):
+ def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000
|