diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 483dd80406..2df4dd4ed4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -25,6 +25,7 @@ from typing import (
Optional,
Set,
Tuple,
+ cast,
)
from synapse.api.errors import Codes, StoreError
@@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore):
Number of devices of this users.
"""
- def count_devices_by_users_txn(txn, user_ids):
+ def count_devices_by_users_txn(
+ txn: LoggingTransaction, user_ids: List[str]
+ ) -> int:
sql = """
SELECT count(*)
FROM devices
@@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
txn.execute(sql + clause, args)
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
if not user_ids:
return 0
@@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
- return list(txn)
+ return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall())
async def _get_device_update_edus_by_remote(
self,
@@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int
) -> int:
- def f(txn):
+ def f(txn: LoggingTransaction) -> int:
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
@@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore):
if not user_ids_to_check:
return set()
- def _get_users_whose_devices_changed_txn(txn):
+ def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
changes = set()
stream_id_where_clause = "stream_id > ?"
@@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore):
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user."""
- def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+ def _mark_remote_user_device_list_as_unsubscribed_txn(
+ txn: LoggingTransaction,
+ ) -> None:
self.db_pool.simple_delete_txn(
txn,
table="device_lists_remote_extremeties",
@@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
def _store_dehydrated_device_txn(
- self, txn, user_id: str, device_id: str, device_data: str
+ self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
) -> Optional[str]:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
@@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
yesterday = self._clock.time_msec() - prune_age
- def _prune_txn(txn):
+ def _prune_txn(txn: LoggingTransaction) -> None:
# look for (user, destination) pairs which have an update older than
# the cutoff.
#
@@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"drop_device_lists_outbound_last_success_non_unique_idx",
)
- async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
- def f(conn):
+ async def _drop_device_list_streams_non_unique_indexes(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ def f(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
@@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
return 1
- async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+ async def _remove_duplicate_outbound_pokes(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
# for some reason, we have accumulated duplicate entries in
# device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
# efficient.
@@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
)
- def _txn(txn):
+ def _txn(txn: LoggingTransaction) -> int:
clause, args = make_tuple_comparison_clause(
[(x, last_row[x]) for x in KEY_COLS]
)
@@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context = get_active_span_text_map()
- def add_device_changes_txn(txn, stream_ids):
+ def add_device_changes_txn(
+ txn: LoggingTransaction, stream_ids: List[int]
+ ) -> None:
self._add_device_change_to_stream_txn(
txn,
user_id,
@@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn: LoggingTransaction,
user_id: str,
device_ids: Collection[str],
- stream_ids: List[str],
- ):
+ stream_ids: List[int],
+ ) -> None:
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id,
@@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
- stream_ids: List[str],
+ stream_ids: List[int],
context: Dict[str, str],
) -> None:
"""Record the user in the room has updated their device."""
@@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
LIMIT ?
"""
- def get_uncoverted_outbound_room_pokes_txn(txn):
+ def get_uncoverted_outbound_room_pokes_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
txn.execute(sql, (limit,))
return [
@@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Marks the associated row in `device_lists_changes_in_room` as handled.
"""
- def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+ def add_device_list_outbound_pokes_txn(
+ txn: LoggingTransaction, stream_ids: List[int]
+ ) -> None:
if hosts:
self._add_device_outbound_poke_to_stream_txn(
txn,
|