summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py51
1 files changed, 33 insertions, 18 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 483dd80406..2df4dd4ed4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -25,6 +25,7 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    cast,
 )
 
 from synapse.api.errors import Codes, StoreError
@@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore):
             Number of devices of this users.
         """
 
-        def count_devices_by_users_txn(txn, user_ids):
+        def count_devices_by_users_txn(
+            txn: LoggingTransaction, user_ids: List[str]
+        ) -> int:
             sql = """
                 SELECT count(*)
                 FROM devices
@@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore):
             )
 
             txn.execute(sql + clause, args)
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         if not user_ids:
             return 0
@@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
 
-        return list(txn)
+        return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall())
 
     async def _get_device_update_edus_by_remote(
         self,
@@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore):
     async def _get_last_device_update_for_remote_user(
         self, destination: str, user_id: str, from_stream_id: int
     ) -> int:
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             prev_sent_id_sql = """
                 SELECT coalesce(max(stream_id), 0) as stream_id
                 FROM device_lists_outbound_last_success
@@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore):
         if not user_ids_to_check:
             return set()
 
-        def _get_users_whose_devices_changed_txn(txn):
+        def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
             changes = set()
 
             stream_id_where_clause = "stream_id > ?"
@@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore):
     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user."""
 
-        def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+        def _mark_remote_user_device_list_as_unsubscribed_txn(
+            txn: LoggingTransaction,
+        ) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="device_lists_remote_extremeties",
@@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     def _store_dehydrated_device_txn(
-        self, txn, user_id: str, device_id: str, device_data: str
+        self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
     ) -> Optional[str]:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
@@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         yesterday = self._clock.time_msec() - prune_age
 
-        def _prune_txn(txn):
+        def _prune_txn(txn: LoggingTransaction) -> None:
             # look for (user, destination) pairs which have an update older than
             # the cutoff.
             #
@@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             "drop_device_lists_outbound_last_success_non_unique_idx",
         )
 
-    async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
-        def f(conn):
+    async def _drop_device_list_streams_non_unique_indexes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def f(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
@@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         )
         return 1
 
-    async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+    async def _remove_duplicate_outbound_pokes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # for some reason, we have accumulated duplicate entries in
         # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
         # efficient.
@@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
         )
 
-        def _txn(txn):
+        def _txn(txn: LoggingTransaction) -> int:
             clause, args = make_tuple_comparison_clause(
                 [(x, last_row[x]) for x in KEY_COLS]
             )
@@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         context = get_active_span_text_map()
 
-        def add_device_changes_txn(txn, stream_ids):
+        def add_device_changes_txn(
+            txn: LoggingTransaction, stream_ids: List[int]
+        ) -> None:
             self._add_device_change_to_stream_txn(
                 txn,
                 user_id,
@@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         txn: LoggingTransaction,
         user_id: str,
         device_ids: Collection[str],
-        stream_ids: List[str],
-    ):
+        stream_ids: List[int],
+    ) -> None:
         txn.call_after(
             self._device_list_stream_cache.entity_has_changed,
             user_id,
@@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_ids: Iterable[str],
         room_ids: Collection[str],
-        stream_ids: List[str],
+        stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
         """Record the user in the room has updated their device."""
@@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             LIMIT ?
         """
 
-        def get_uncoverted_outbound_room_pokes_txn(txn):
+        def get_uncoverted_outbound_room_pokes_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
             txn.execute(sql, (limit,))
 
             return [
@@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Marks the associated row in `device_lists_changes_in_room` as handled.
         """
 
-        def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+        def add_device_list_outbound_pokes_txn(
+            txn: LoggingTransaction, stream_ids: List[int]
+        ) -> None:
             if hosts:
                 self._add_device_outbound_poke_to_stream_txn(
                     txn,