summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/database.py21
1 files changed, 16 insertions, 5 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 5efe31aa19..fec4ae5b97 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -34,6 +34,7 @@ from typing import (
     Tuple,
     Type,
     TypeVar,
+    Union,
     cast,
     overload,
 )
@@ -100,6 +101,15 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 }
 
 
+class _PoolConnection(Connection):
+    """
+    A Connection from twisted.enterprise.adbapi.Connection.
+    """
+
+    def reconnect(self) -> None:
+        ...
+
+
 def make_pool(
     reactor: IReactorCore,
     db_config: DatabaseConnectionConfig,
@@ -856,7 +866,8 @@ class DatabasePool:
             try:
                 with opentracing.start_active_span(f"db.{desc}"):
                     result = await self.runWithConnection(
-                        self.new_transaction,
+                        # mypy seems to have an issue with this, maybe a bug?
+                        self.new_transaction,  # type: ignore[arg-type]
                         desc,
                         after_callbacks,
                         async_after_callbacks,
@@ -892,7 +903,7 @@ class DatabasePool:
 
     async def runWithConnection(
         self,
-        func: Callable[..., R],
+        func: Callable[Concatenate[LoggingDatabaseConnection, P], R],
         *args: Any,
         db_autocommit: bool = False,
         isolation_level: Optional[int] = None,
@@ -926,7 +937,7 @@ class DatabasePool:
 
         start_time = monotonic_time()
 
-        def inner_func(conn, *args, **kwargs):
+        def inner_func(conn: _PoolConnection, *args: P.args, **kwargs: P.kwargs) -> R:
             # We shouldn't be in a transaction. If we are then something
             # somewhere hasn't committed after doing work. (This is likely only
             # possible during startup, as `run*` will ensure changes are
@@ -1019,7 +1030,7 @@ class DatabasePool:
         decoder: Optional[Callable[[Cursor], R]],
         query: str,
         *args: Any,
-    ) -> R:
+    ) -> Union[List[Tuple[Any, ...]], R]:
         """Runs a single query for a result set.
 
         Args:
@@ -1032,7 +1043,7 @@ class DatabasePool:
             The result of decoder(results)
         """
 
-        def interaction(txn):
+        def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
             txn.execute(query, args)
             if decoder:
                 return decoder(txn)