summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9299.misc1
-rw-r--r--synapse/storage/database.py14
-rw-r--r--synapse/storage/prepare_database.py4
-rw-r--r--synapse/storage/types.py37
-rw-r--r--synapse/storage/util/sequence.py8
5 files changed, 47 insertions, 17 deletions
diff --git a/changelog.d/9299.misc b/changelog.d/9299.misc
new file mode 100644
index 0000000000..c883a677ed
--- /dev/null
+++ b/changelog.d/9299.misc
@@ -0,0 +1 @@
+Update the `Cursor` type hints to better match PEP 249.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d2ba4bd2fc..ae4bf1a54f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -158,8 +158,8 @@ class LoggingDatabaseConnection:
     def commit(self) -> None:
         self.conn.commit()
 
-    def rollback(self, *args, **kwargs) -> None:
-        self.conn.rollback(*args, **kwargs)
+    def rollback(self) -> None:
+        self.conn.rollback()
 
     def __enter__(self) -> "Connection":
         self.conn.__enter__()
@@ -244,12 +244,15 @@ class LoggingTransaction:
         assert self.exception_callbacks is not None
         self.exception_callbacks.append((callback, args, kwargs))
 
+    def fetchone(self) -> Optional[Tuple]:
+        return self.txn.fetchone()
+
+    def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+        return self.txn.fetchmany(size=size)
+
     def fetchall(self) -> List[Tuple]:
         return self.txn.fetchall()
 
-    def fetchone(self) -> Tuple:
-        return self.txn.fetchone()
-
     def __iter__(self) -> Iterator[Tuple]:
         return self.txn.__iter__()
 
@@ -754,6 +757,7 @@ class DatabasePool:
         Returns:
             A list of dicts where the key is the column header.
         """
+        assert cursor.description is not None, "cursor.description was None"
         col_headers = [intern(str(column[0])) for column in cursor.description]
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 566ea19bae..28bb2eb662 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -619,9 +619,9 @@ def _get_or_create_schema_state(
 
     txn.execute("SELECT version, upgraded FROM schema_version")
     row = txn.fetchone()
-    current_version = int(row[0]) if row else None
 
-    if current_version:
+    if row is not None:
+        current_version = int(row[0])
         txn.execute(
             "SELECT file FROM applied_schema_deltas WHERE version >= ?",
             (current_version,),
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 9cadcba18f..17291c9d5e 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
 
 from typing_extensions import Protocol
 
@@ -20,23 +20,44 @@ from typing_extensions import Protocol
 Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
 """
 
+_Parameters = Union[Sequence[Any], Mapping[str, Any]]
+
 
 class Cursor(Protocol):
-    def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+    def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
         ...
 
-    def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+    def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
         ...
 
-    def fetchall(self) -> List[Tuple]:
+    def fetchone(self) -> Optional[Tuple]:
+        ...
+
+    def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
         ...
 
-    def fetchone(self) -> Tuple:
+    def fetchall(self) -> List[Tuple]:
         ...
 
     @property
-    def description(self) -> Any:
-        return None
+    def description(
+        self,
+    ) -> Optional[
+        Sequence[
+            # Note that this is an approximate typing based on sqlite3 and other
+            # drivers, and may not be entirely accurate.
+            Tuple[
+                str,
+                Optional[Any],
+                Optional[int],
+                Optional[int],
+                Optional[int],
+                Optional[int],
+                Optional[int],
+            ]
+        ]
+    ]:
+        ...
 
     @property
     def rowcount(self) -> int:
@@ -59,7 +80,7 @@ class Connection(Protocol):
     def commit(self) -> None:
         ...
 
-    def rollback(self, *args, **kwargs) -> None:
+    def rollback(self) -> None:
         ...
 
     def __enter__(self) -> "Connection":
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 0ec4dc2918..e2b316a218 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
 
     def get_next_id_txn(self, txn: Cursor) -> int:
         txn.execute("SELECT nextval(?)", (self._sequence_name,))
-        return txn.fetchone()[0]
+        fetch_res = txn.fetchone()
+        assert fetch_res is not None
+        return fetch_res[0]
 
     def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
         txn.execute(
@@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
         txn.execute(
             "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
         )
-        last_value, is_called = txn.fetchone()
+        fetch_res = txn.fetchone()
+        assert fetch_res is not None
+        last_value, is_called = fetch_res
 
         # If we have an associated stream check the stream_positions table.
         max_in_stream_positions = None