summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py29
1 files changed, 14 insertions, 15 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index af8796ad92..8851710d47 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,6 +28,7 @@ from typing import (
     Optional,
     Tuple,
     TypeVar,
+    cast,
     overload,
 )
 
@@ -35,7 +36,6 @@ from prometheus_client import Histogram
 from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
-from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -507,8 +507,9 @@ class DatabasePool(object):
             self._txn_perf_counters.update(desc, duration)
             sql_txn_timer.labels(desc).observe(duration)
 
-    @defer.inlineCallbacks
-    def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+    async def runInteraction(
+        self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+    ) -> R:
         """Starts a transaction on the database and runs a given function
 
         Arguments:
@@ -521,7 +522,7 @@ class DatabasePool(object):
             kwargs: named args to pass to `func`
 
         Returns:
-            Deferred: The result of func
+            The result of func
         """
         after_callbacks = []  # type: List[_CallbackListEntry]
         exception_callbacks = []  # type: List[_CallbackListEntry]
@@ -530,16 +531,14 @@ class DatabasePool(object):
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
-            result = yield defer.ensureDeferred(
-                self.runWithConnection(
-                    self.new_transaction,
-                    desc,
-                    after_callbacks,
-                    exception_callbacks,
-                    func,
-                    *args,
-                    **kwargs
-                )
+            result = await self.runWithConnection(
+                self.new_transaction,
+                desc,
+                after_callbacks,
+                exception_callbacks,
+                func,
+                *args,
+                **kwargs
             )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
@@ -549,7 +548,7 @@ class DatabasePool(object):
                 after_callback(*after_args, **after_kwargs)
             raise
 
-        return result
+        return cast(R, result)
 
     async def runWithConnection(
         self, func: "Callable[..., R]", *args: Any, **kwargs: Any