diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 7ab370efef..78ca6d8346 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,6 +28,7 @@ from typing import (
Optional,
Tuple,
TypeVar,
+ cast,
overload,
)
@@ -35,7 +36,6 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
-from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -507,8 +507,9 @@ class DatabasePool(object):
self._txn_perf_counters.update(desc, duration)
sql_txn_timer.labels(desc).observe(duration)
- @defer.inlineCallbacks
- def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+ async def runInteraction(
+ self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ ) -> R:
"""Starts a transaction on the database and runs a given function
Arguments:
@@ -521,7 +522,7 @@ class DatabasePool(object):
kwargs: named args to pass to `func`
Returns:
- Deferred: The result of func
+ The result of func
"""
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
@@ -530,16 +531,14 @@ class DatabasePool(object):
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
- result = yield defer.ensureDeferred(
- self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- **kwargs
- )
+ result = await self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ **kwargs
)
for after_callback, after_args, after_kwargs in after_callbacks:
@@ -549,7 +548,7 @@ class DatabasePool(object):
after_callback(*after_args, **after_kwargs)
raise
- return result
+ return cast(R, result)
async def runWithConnection(
self, func: "Callable[..., R]", *args: Any, **kwargs: Any
@@ -604,6 +603,18 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
+ @overload
+ async def execute(
+ self, desc: str, decoder: Literal[None], query: str, *args: Any
+ ) -> List[Tuple[Any, ...]]:
+ ...
+
+ @overload
+ async def execute(
+ self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+ ) -> R:
+ ...
+
async def execute(
self,
desc: str,
@@ -1088,6 +1099,28 @@ class DatabasePool(object):
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
+ @overload
+ async def simple_select_one_onecol(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: Literal[False] = False,
+ desc: str = "simple_select_one_onecol",
+ ) -> Any:
+ ...
+
+ @overload
+ async def simple_select_one_onecol(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: Literal[True] = True,
+ desc: str = "simple_select_one_onecol",
+ ) -> Optional[Any]:
+ ...
+
async def simple_select_one_onecol(
self,
table: str,
@@ -1116,6 +1149,30 @@ class DatabasePool(object):
allow_none=allow_none,
)
+ @overload
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: Literal[False] = False,
+ ) -> Any:
+ ...
+
+ @overload
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: Literal[True] = True,
+ ) -> Optional[Any]:
+ ...
+
@classmethod
def simple_select_one_onecol_txn(
cls,
|