summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2023-04-14 19:04:49 +0100
committerGitHub <noreply@github.com>2023-04-14 18:04:49 +0000
commit8a47d6e3a685bd45237b7dae9c138209df509f64 (patch)
tree976ca29fd809e11be0c909f46059b9006cfc2fc5 /synapse/storage
parentDisable directory listing for `StaticResource` (#15438) (diff)
downloadsynapse-8a47d6e3a685bd45237b7dae9c138209df509f64.tar.xz
More precise type for LoggingTransaction.execute (#15432)
* More precise type for LoggingTransaction.execute
* Add an annotation for stream_ordering_month_ago

This would have spotted the error that was fixed in "Add comma missing from #15382. (#15429)"
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/database.py20
-rw-r--r--synapse/storage/databases/main/event_federation.py19
-rw-r--r--synapse/storage/types.py6
3 files changed, 31 insertions, 14 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 226ccc1671..1f5f5eb6f8 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -58,7 +58,7 @@ from synapse.metrics import register_threadpool
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Connection, Cursor, SQLQueryParameters
 from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
 
@@ -371,10 +371,18 @@ class LoggingTransaction:
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch
 
+            # TODO: is it safe for values to be Iterable[Iterable[Any]] here?
+            # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_batch
+            # suggests each arg in args should be a sequence or mapping
             self._do_execute(
                 lambda the_sql: execute_batch(self.txn, the_sql, args), sql
             )
         else:
+            # TODO: is it safe for values to be Iterable[Iterable[Any]] here?
+            # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#sqlite3.Cursor.executemany
+            # suggests that the outer collection may be iterable, but
+            # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#how-to-use-placeholders-to-bind-values-in-sql-queries
+            # suggests that the inner collection should be a sequence or dict.
             self.executemany(sql, args)
 
     def execute_values(
@@ -390,14 +398,20 @@ class LoggingTransaction:
         from psycopg2.extras import execute_values
 
         return self._do_execute(
+            # TODO: is it safe for values to be Iterable[Iterable[Any]] here?
+            # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
             lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
             sql,
         )
 
-    def execute(self, sql: str, *args: Any) -> None:
-        self._do_execute(self.txn.execute, sql, *args)
+    def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
+        self._do_execute(self.txn.execute, sql, parameters)
 
     def executemany(self, sql: str, *args: Any) -> None:
+        # TODO: we should add a type for *args here. Looking at Cursor.executemany
+        # and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
+        # Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy
+        # complains about.
         self._do_execute(self.txn.executemany, sql, *args)
 
     def executescript(self, sql: str) -> None:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 2ad6fa7d5e..ac19de183c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -114,6 +114,10 @@ class _NoChainCoverIndex(Exception):
 
 
 class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
+    # TODO: this attribute comes from EventPushActionWorkerStore. Should we inherit from
+    # that store so that mypy can deduce this for itself?
+    stream_ordering_month_ago: Optional[int]
+
     def __init__(
         self,
         database: DatabasePool,
@@ -1182,8 +1186,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         Throws a StoreError if we have since purged the index for
         stream_orderings from that point.
         """
-
-        if stream_ordering <= self.stream_ordering_month_ago:  # type: ignore[attr-defined]
+        assert self.stream_ordering_month_ago is not None
+        if stream_ordering <= self.stream_ordering_month_ago:
             raise StoreError(400, f"stream_ordering too old {stream_ordering}")
 
         sql = """
@@ -1231,7 +1235,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # provided the last_change is recent enough, we now clamp the requested
         # stream_ordering to it.
-        if last_change > self.stream_ordering_month_ago:  # type: ignore[attr-defined]
+        assert self.stream_ordering_month_ago is not None
+        if last_change > self.stream_ordering_month_ago:
             stream_ordering = min(last_change, stream_ordering)
 
         return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@@ -1246,8 +1251,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         Throws a StoreError if we have since purged the index for
         stream_orderings from that point.
         """
-
-        if stream_ordering <= self.stream_ordering_month_ago:  # type: ignore[attr-defined]
+        assert self.stream_ordering_month_ago is not None
+        if stream_ordering <= self.stream_ordering_month_ago:
             raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
 
         sql = """
@@ -1707,9 +1712,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 DELETE FROM stream_ordering_to_exterm
                 WHERE stream_ordering < ?
             """
-            txn.execute(
-                sql, (self.stream_ordering_month_ago,)  # type: ignore[attr-defined]
-            )
+            txn.execute(sql, (self.stream_ordering_month_ago,))
 
         await self.db_pool.runInteraction(
             "_delete_old_forward_extrem_cache",
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 56a0048539..34ac807530 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -31,14 +31,14 @@ from typing_extensions import Protocol
 Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
 """
 
-_Parameters = Union[Sequence[Any], Mapping[str, Any]]
+SQLQueryParameters = Union[Sequence[Any], Mapping[str, Any]]
 
 
 class Cursor(Protocol):
-    def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
+    def execute(self, sql: str, parameters: SQLQueryParameters = ...) -> Any:
         ...
 
-    def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
+    def executemany(self, sql: str, parameters: Sequence[SQLQueryParameters]) -> Any:
         ...
 
     def fetchone(self) -> Optional[Tuple]: