summary refs log tree commit diff
path: root/synapse/metrics/background_process_metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/metrics/background_process_metrics.py')
-rw-r--r--synapse/metrics/background_process_metrics.py16
1 files changed, 5 insertions, 11 deletions
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 658f6ecd72..70e0fa45d9 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import logging
 import threading
 from functools import wraps
@@ -25,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     import resource
@@ -199,19 +199,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
         _background_process_start_count.labels(desc).inc()
         _background_process_in_flight_count.labels(desc).inc()
 
-        with BackgroundProcessLoggingContext(desc) as context:
-            context.request = "%s-%i" % (desc, count)
+        with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
                     ctx = start_active_span(desc, tags={"request_id": context.request})
                 with ctx:
-                    result = func(*args, **kwargs)
-
-                    if inspect.isawaitable(result):
-                        result = await result
-
-                    return result
+                    return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
                 logger.exception(
                     "Background process '%s' threw an exception", desc,
@@ -249,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
 
     __slots__ = ["_proc"]
 
-    def __init__(self, name: str):
-        super().__init__(name)
+    def __init__(self, name: str, request: Optional[str] = None):
+        super().__init__(name, request=request)
 
         self._proc = _BackgroundProcess(name, self)