summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12311.misc1
-rw-r--r--synapse/storage/database.py17
2 files changed, 8 insertions, 10 deletions
diff --git a/changelog.d/12311.misc b/changelog.d/12311.misc
new file mode 100644
index 0000000000..df0e824a7e
--- /dev/null
+++ b/changelog.d/12311.misc
@@ -0,0 +1 @@
+Improve type annotations for `execute_values`.
\ No newline at end of file
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 367709a1a7..72fef1533f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -290,11 +290,15 @@ class LoggingTransaction:
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch
 
-            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+            self._do_execute(
+                lambda the_sql: execute_batch(self.txn, the_sql, args), sql
+            )
         else:
             self.executemany(sql, args)
 
-    def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
+    def execute_values(
+        self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
+    ) -> List[Tuple]:
         """Corresponds to psycopg2.extras.execute_values. Only available when
         using postgres.
 
@@ -305,15 +309,8 @@ class LoggingTransaction:
         from psycopg2.extras import execute_values
 
         return self._do_execute(
-            # Type ignore: mypy is unhappy because if `x` is a 5-tuple, then there will
-            # be two values for `fetch`: one given positionally, and another given
-            # as a keyword argument. We might be able to fix this by
-            # - propagating the signature of psycopg2.extras.execute_values to this
-            #   function, or
-            # - changing `*args: Any` to `values: T` for some appropriate T.
-            lambda *x: execute_values(self.txn, *x, fetch=fetch),  # type: ignore[misc]
+            lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
             sql,
-            *args,
         )
 
     def execute(self, sql: str, *args: Any) -> None: