summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py43
1 files changed, 29 insertions, 14 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index c237d003bc..964078aed4 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
 import logging
 from itertools import islice
 
@@ -67,9 +66,12 @@ class Clock(object):
             f(function): The function to call repeatedly.
             msec(float): How long to wait between calls in milliseconds.
         """
-        call = task.LoopingCall(_log_exception_wrapper(f))
+        call = task.LoopingCall(f)
         call.clock = self._reactor
-        call.start(msec / 1000.0, now=False)
+        d = call.start(msec / 1000.0, now=False)
+        d.addErrback(make_log_failure_errback(
+            "Looping call died", consumeErrors=False,
+        ))
         return call
 
     def call_later(self, delay, callback, *args, **kwargs):
@@ -112,17 +114,30 @@ def batch_iter(iterable, size):
     return iter(lambda: tuple(islice(sourceiter, size)), ())
 
 
-def _log_exception_wrapper(f):
-    """Used to wrap looping calls to log loudly if they get killed
+def make_log_failure_errback(msg, consumeErrors=True):
+    """Creates a function suitable for passing to `Deferred.addErrback` that
+    logs any failures that occur.
+
+    Args:
+        msg (str): Message to log
+        consumeErrors (bool): If true consumes the failure, otherwise passes
+            on down the callback chain
+
+    Returns:
+        func(Failure)
     """
 
-    @functools.wraps(f)
-    def wrap(*args, **kwargs):
-        try:
-            logger.info("Running looping call")
-            return f(*args, **kwargs)
-        except:  # noqa: E722, as we reraise the exception this is fine.
-            logger.exception("Looping called died")
-            raise
+    def log_failure(failure):
+        logger.error(
+            msg,
+            exc_info=(
+                failure.type,
+                failure.value,
+                failure.getTracebackObject()
+            )
+        )
+
+        if not consumeErrors:
+            return failure
 
-    return wrap
+    return log_failure