summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2021-12-29 14:04:28 +0100
committerGitHub <noreply@github.com>2021-12-29 13:04:28 +0000
commitf82d38ed2e8e07c81a0d60f31401e3edecd5e57f (patch)
tree0dd17fb84983341a160d7ed6d189f52ef63006bb /synapse/storage
parentDo not attempt to bundled aggregations for /members and /state. (#11623) (diff)
downloadsynapse-f82d38ed2e8e07c81a0d60f31401e3edecd5e57f.tar.xz
Improve type hints in storage classes. (#11652)
By using cast and making ignores more specific.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/event_federation.py6
-rw-r--r--synapse/storage/databases/main/event_push_actions.py18
-rw-r--r--synapse/storage/databases/main/filtering.py4
-rw-r--r--synapse/storage/databases/main/media_repository.py3
-rw-r--r--synapse/storage/databases/main/pusher.py6
-rw-r--r--synapse/storage/databases/main/registration.py17
-rw-r--r--synapse/storage/databases/main/relations.py4
-rw-r--r--synapse/storage/databases/main/ui_auth.py15
9 files changed, 43 insertions, 34 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index b410eefdc7..3682cb6a81 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -673,7 +673,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
                 # There's a type mismatch here between how we want to type the row and
                 # what fetchone says it returns, but we silence it because we know that
                 # res can't be None.
-                res: Tuple[Optional[int]] = txn.fetchone()  # type: ignore[assignment]
+                res = cast(Tuple[Optional[int]], txn.fetchone())
                 if res[0] is None:
                     # this can only happen if the `device_inbox` table is empty, in which
                     # case we have no work to do.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index bc5ff25d08..270b30800b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -288,7 +288,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             new_front = set()
             for chunk in batch_iter(front, 100):
                 # Pull the auth events either from the cache or DB.
-                to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+                to_fetch: List[str] = []  # Event IDs to fetch from DB
                 for event_id in chunk:
                     res = self._event_auth_cache.get(event_id)
                     if res is None:
@@ -615,8 +615,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             # currently walking, either from cache or DB.
             search, chunk = search[:-100], search[-100:]
 
-            found = []  # Results found  # type: List[Tuple[str, str, int]]
-            to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+            found: List[Tuple[str, str, int]] = []  # Results found
+            to_fetch: List[str] = []  # Event IDs to fetch from DB
             for _, event_id in chunk:
                 res = self._event_auth_cache.get(event_id)
                 if res is None:
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 98ea0e884c..a98e6b2593 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
@@ -326,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()  # type: ignore[return-value]
+            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
         after_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
@@ -357,7 +357,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()  # type: ignore[return-value]
+            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
         no_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
@@ -434,7 +434,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()  # type: ignore[return-value]
+            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
         after_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
@@ -465,7 +465,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()  # type: ignore[return-value]
+            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
         no_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
@@ -662,7 +662,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             The stream ordering
         """
         txn.execute("SELECT MAX(stream_ordering) FROM events")
-        max_stream_ordering = txn.fetchone()[0]  # type: ignore[index]
+        max_stream_ordering = cast(Tuple[Optional[int]], txn.fetchone())[0]
 
         if max_stream_ordering is None:
             return 0
@@ -731,7 +731,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 " LIMIT 1"
             )
             txn.execute(sql, (stream_ordering,))
-            return txn.fetchone()  # type: ignore[return-value]
+            return cast(Optional[Tuple[int]], txn.fetchone())
 
         result = await self.db_pool.runInteraction(
             "get_time_of_last_push_action_before", f
@@ -1029,7 +1029,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 " LIMIT ?" % (before_clause,)
             )
             txn.execute(sql, args)
-            return txn.fetchall()  # type: ignore[return-value]
+            return cast(
+                List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall()
+            )
 
         push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
         return [
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index cf842803bc..cb9ee08fa8 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union
+from typing import Optional, Tuple, Union, cast
 
 from canonicaljson import encode_canonical_json
 
@@ -63,7 +63,7 @@ class FilteringStore(SQLBaseStore):
 
             sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
             txn.execute(sql, (user_localpart,))
-            max_id = txn.fetchone()[0]  # type: ignore[index]
+            max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
             if max_id is None:
                 filter_id = 0
             else:
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 1b076683f7..cbba356b4a 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -23,6 +23,7 @@ from typing import (
     Optional,
     Tuple,
     Union,
+    cast,
 )
 
 from synapse.storage._base import SQLBaseStore
@@ -220,7 +221,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 WHERE user_id = ?
             """
             txn.execute(sql, args)
-            count = txn.fetchone()[0]  # type: ignore[index]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = """
                 SELECT
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7ab681ed6f..747b4f31df 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -494,7 +494,7 @@ class PusherStore(PusherWorkerStore):
                 # invalidate, since we the user might not have had a pusher before
                 await self.db_pool.runInteraction(
                     "add_pusher",
-                    self._invalidate_cache_and_stream,  # type: ignore
+                    self._invalidate_cache_and_stream,  # type: ignore[attr-defined]
                     self.get_if_user_has_pusher,
                     (user_id,),
                 )
@@ -503,7 +503,7 @@ class PusherStore(PusherWorkerStore):
         self, app_id: str, pushkey: str, user_id: str
     ) -> None:
         def delete_pusher_txn(txn, stream_id):
-            self._invalidate_cache_and_stream(  # type: ignore
+            self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
@@ -548,7 +548,7 @@ class PusherStore(PusherWorkerStore):
         pushers = list(await self.get_pushers_by_user_id(user_id))
 
         def delete_pushers_txn(txn, stream_ids):
-            self._invalidate_cache_and_stream(  # type: ignore
+            self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 29d9d4de96..4175c82a25 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
@@ -1357,12 +1357,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             # Override type because the return type is only optional if
             # allow_none is True, and we don't want mypy throwing errors
             # about None not being indexable.
-            res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
-                txn,
-                "registration_tokens",
-                keyvalues={"token": token},
-                retcols=["pending", "completed"],
-            )  # type: ignore
+            res = cast(
+                Dict[str, Any],
+                self.db_pool.simple_select_one_txn(
+                    txn,
+                    "registration_tokens",
+                    keyvalues={"token": token},
+                    retcols=["pending", "completed"],
+                ),
+            )
 
             # Decrement pending and increment completed
             self.db_pool.simple_update_one_txn(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 729ff17e2e..4ff6aed253 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, cast
 
 import attr
 
@@ -399,7 +399,7 @@ class RelationsWorkerStore(SQLBaseStore):
                     AND relation_type = ?
             """
             txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
-            count = txn.fetchone()[0]  # type: ignore[index]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             return count, latest_event_id
 
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 340ca9e47d..a1a1a6a14a 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
@@ -225,11 +225,14 @@ class UIAuthWorkerStore(SQLBaseStore):
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
     ):
         # Get the current value.
-        result: Dict[str, Any] = self.db_pool.simple_select_one_txn(  # type: ignore
-            txn,
-            table="ui_auth_sessions",
-            keyvalues={"session_id": session_id},
-            retcols=("serverdict",),
+        result = cast(
+            Dict[str, Any],
+            self.db_pool.simple_select_one_txn(
+                txn,
+                table="ui_auth_sessions",
+                keyvalues={"session_id": session_id},
+                retcols=("serverdict",),
+            ),
         )
 
         # Update it and add it back to the database.