summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12666.misc1
-rw-r--r--pyproject.toml2
-rw-r--r--synapse/storage/database.py19
3 files changed, 16 insertions, 6 deletions
diff --git a/changelog.d/12666.misc b/changelog.d/12666.misc
new file mode 100644
index 0000000000..96268e33f5
--- /dev/null
+++ b/changelog.d/12666.misc
@@ -0,0 +1 @@
+Use `Concatenate` to better annotate `_do_execute`.
diff --git a/pyproject.toml b/pyproject.toml
index 7348230fba..4c51b8c4a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -142,7 +142,7 @@ netaddr = ">=0.7.18"
 # add a lower bound to the Jinja2 dependency.
 Jinja2 = ">=3.0"
 bleach = ">=1.4.3"
-# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
+# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
 typing-extensions = ">=3.10.0"
 # We enforce that we have a `cryptography` version that bundles an `openssl`
 # with the latest security patches.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index df1e9c1b83..2255e55f6f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -38,7 +38,7 @@ from typing import (
 
 import attr
 from prometheus_client import Histogram
-from typing_extensions import Literal
+from typing_extensions import Concatenate, Literal, ParamSpec
 
 from twisted.enterprise import adbapi
 
@@ -194,7 +194,7 @@ class LoggingDatabaseConnection:
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 _CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
 
-
+P = ParamSpec("P")
 R = TypeVar("R")
 
 
@@ -339,7 +339,13 @@ class LoggingTransaction:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
+    def _do_execute(
+        self,
+        func: Callable[Concatenate[str, P], R],
+        sql: str,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> R:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -348,7 +354,10 @@ class LoggingTransaction:
         sql = self.database_engine.convert_param_style(sql)
         if args:
             try:
-                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
+                # The type-ignore should be redundant once mypy releases a version with
+                # https://github.com/python/mypy/pull/12668. (`args` might be empty,
+                # (but we'll catch the index error if so.)
+                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])  # type: ignore[index]
             except Exception:
                 # Don't let logging failures stop SQL from working
                 pass
@@ -363,7 +372,7 @@ class LoggingTransaction:
                     opentracing.tags.DATABASE_STATEMENT: sql,
                 },
             ):
-                return func(sql, *args)
+                return func(sql, *args, **kwargs)
         except Exception as e:
             sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
             raise