Diffstat (limited to 'synapse/util')
-rw-r--r--  synapse/util/__init__.py                      23
-rw-r--r--  synapse/util/async_helpers.py                 41
-rw-r--r--  synapse/util/caches/__init__.py                6
-rw-r--r--  synapse/util/caches/descriptors.py           100
-rw-r--r--  synapse/util/caches/dictionary_cache.py       15
-rw-r--r--  synapse/util/caches/expiringcache.py          18
-rw-r--r--  synapse/util/caches/lrucache.py               14
-rw-r--r--  synapse/util/caches/response_cache.py         24
-rw-r--r--  synapse/util/caches/stream_change_cache.py    13
-rw-r--r--  synapse/util/caches/treecache.py               1
-rw-r--r--  synapse/util/caches/ttlcache.py                1
-rw-r--r--  synapse/util/distributor.py                   29
-rw-r--r--  synapse/util/frozenutils.py                    9
-rw-r--r--  synapse/util/httpresourcetree.py               8
-rw-r--r--  synapse/util/jsonobject.py                     6
-rw-r--r--  synapse/util/logcontext.py                   151
-rw-r--r--  synapse/util/logformatter.py                   3
-rw-r--r--  synapse/util/logutils.py                      51
-rw-r--r--  synapse/util/manhole.py                       18
-rw-r--r--  synapse/util/metrics.py                       32
-rw-r--r--  synapse/util/module_loader.py                  6
-rw-r--r--  synapse/util/msisdn.py                         6
-rw-r--r--  synapse/util/ratelimitutils.py                35
-rw-r--r--  synapse/util/retryutils.py                    55
-rw-r--r--  synapse/util/stringutils.py                   16
-rw-r--r--  synapse/util/threepids.py                     10
-rw-r--r--  synapse/util/versionstring.py                 61
-rw-r--r--  synapse/util/wheel_timer.py                    4
28 files changed, 407 insertions, 349 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 0ae7e2ef3b..dcc747cac1 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -40,6 +40,7 @@ class Clock(object):
     Args:
         reactor: The Twisted reactor to use.
     """
+
     _reactor = attr.ib()
 
     @defer.inlineCallbacks
@@ -70,9 +71,7 @@ class Clock(object):
         call = task.LoopingCall(f)
         call.clock = self._reactor
         d = call.start(msec / 1000.0, now=False)
-        d.addErrback(
-            log_failure, "Looping call died", consumeErrors=False,
-        )
+        d.addErrback(log_failure, "Looping call died", consumeErrors=False)
         return call
 
     def call_later(self, delay, callback, *args, **kwargs):
@@ -84,6 +83,7 @@ class Clock(object):
             *args: Positional arguments to pass to function.
             **kwargs: Keyword arguments to pass to function.
         """
+
         def wrapped_callback(*args, **kwargs):
             with PreserveLoggingContext():
                 callback(*args, **kwargs)
@@ -129,12 +129,7 @@ def log_failure(failure, msg, consumeErrors=True):
     """
 
     logger.error(
-        msg,
-        exc_info=(
-            failure.type,
-            failure.value,
-            failure.getTracebackObject()
-        )
+        msg, exc_info=(failure.type, failure.value, failure.getTracebackObject())
     )
 
     if not consumeErrors:
@@ -152,12 +147,12 @@ def glob_to_regex(glob):
     Returns:
         re.RegexObject
     """
-    res = ''
+    res = ""
     for c in glob:
-        if c == '*':
-            res = res + '.*'
-        elif c == '?':
-            res = res + '.'
+        if c == "*":
+            res = res + ".*"
+        elif c == "?":
+            res = res + "."
         else:
             res = res + re.escape(c)
 
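For illustration, a standalone sketch of the translation loop above. It is not part of the diff, and the anchoring of the final pattern lies outside this hunk, so re.fullmatch is an assumption here.

    import re

    def glob_to_regex_sketch(glob):
        res = ""
        for c in glob:
            if c == "*":
                res = res + ".*"  # glob * matches any run of characters
            elif c == "?":
                res = res + "."   # glob ? matches exactly one character
            else:
                res = res + re.escape(c)
        return re.compile(res)

    # e.g. a server-name glob:
    assert glob_to_regex_sketch("*.example.com").fullmatch("matrix.example.com")
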
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 7253ba120f..7757b8708a 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -95,6 +95,7 @@ class ObservableDeferred(object):
             def remove(r):
                 self._observers.discard(d)
                 return r
+
             d.addBoth(remove)
 
             self._observers.add(d)
@@ -123,7 +124,9 @@ class ObservableDeferred(object):
 
     def __repr__(self):
         return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
-            id(self), self._result, self._deferred,
+            id(self),
+            self._result,
+            self._deferred,
         )
 
 
@@ -150,10 +153,12 @@ def concurrently_execute(func, args, limit):
         except StopIteration:
             pass
 
-    return logcontext.make_deferred_yieldable(defer.gatherResults([
-        run_in_background(_concurrently_execute_inner)
-        for _ in range(limit)
-    ], consumeErrors=True)).addErrback(unwrapFirstError)
+    return logcontext.make_deferred_yieldable(
+        defer.gatherResults(
+            [run_in_background(_concurrently_execute_inner) for _ in range(limit)],
+            consumeErrors=True,
+        )
+    ).addErrback(unwrapFirstError)
 
 
 def yieldable_gather_results(func, iter, *args, **kwargs):
@@ -169,10 +174,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
         Deferred[list]: Resolved when all functions have been invoked, or errors if
         one of the function calls fails.
     """
-    return logcontext.make_deferred_yieldable(defer.gatherResults([
-        run_in_background(func, item, *args, **kwargs)
-        for item in iter
-    ], consumeErrors=True)).addErrback(unwrapFirstError)
+    return logcontext.make_deferred_yieldable(
+        defer.gatherResults(
+            [run_in_background(func, item, *args, **kwargs) for item in iter],
+            consumeErrors=True,
+        )
+    ).addErrback(unwrapFirstError)
 
 
 class Linearizer(object):
@@ -185,6 +192,7 @@ class Linearizer(object):
             # do some work.
 
     """
+
     def __init__(self, name=None, max_count=1, clock=None):
         """
         Args:
@@ -197,6 +205,7 @@ class Linearizer(object):
 
         if not clock:
             from twisted.internet import reactor
+
             clock = Clock(reactor)
         self._clock = clock
         self.max_count = max_count
@@ -221,7 +230,7 @@ class Linearizer(object):
             res = self._await_lock(key)
         else:
             logger.debug(
-                "Acquired uncontended linearizer lock %r for key %r", self.name, key,
+                "Acquired uncontended linearizer lock %r for key %r", self.name, key
             )
             entry[0] += 1
             res = defer.succeed(None)
@@ -266,9 +275,7 @@ class Linearizer(object):
         """
         entry = self.key_to_defer[key]
 
-        logger.debug(
-            "Waiting to acquire linearizer lock %r for key %r", self.name, key,
-        )
+        logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 
         new_defer = make_deferred_yieldable(defer.Deferred())
         entry[1][new_defer] = 1
@@ -293,14 +300,14 @@ class Linearizer(object):
             logger.info("defer %r got err %r", new_defer, e)
             if isinstance(e, CancelledError):
                 logger.debug(
-                    "Cancelling wait for linearizer lock %r for key %r",
-                    self.name, key,
+                    "Cancelling wait for linearizer lock %r for key %r", self.name, key
                 )
 
             else:
                 logger.warn(
                     "Unexpected exception waiting for linearizer lock %r for key %r",
-                    self.name, key,
+                    self.name,
+                    key,
                 )
 
             # we just have to take ourselves back out of the queue.
@@ -438,7 +445,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
 
         try:
             deferred.cancel()
-        except:   # noqa: E722, if we throw any exception it'll break time outs
+        except:  # noqa: E722, if we throw any exception it'll break time outs
             logger.exception("Canceller failed during timeout")
 
         if not new_d.called:
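
As an aside (not part of the diff), the fan-out pattern in concurrently_execute above can be sketched in plain asyncio terms: each of `limit` workers pulls work from one shared iterator, so at most `limit` invocations of func run at once.

    import asyncio

    async def concurrently_execute_sketch(func, args, limit):
        it = iter(args)  # shared between workers, as in the Twisted original

        async def worker():
            # single-threaded event loop, so next(it) is race-free
            for arg in it:
                await func(arg)

        await asyncio.gather(*(worker() for _ in range(limit)))
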
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index f37d5bec08..8271229015 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -104,8 +104,8 @@ def register_cache(cache_type, cache_name, cache):
 
 
 KNOWN_KEYS = {
-    key: key for key in
-    (
+    key: key
+    for key in (
         "auth_events",
         "content",
         "depth",
@@ -150,7 +150,7 @@ def intern_dict(dictionary):
 
 
 def _intern_known_values(key, value):
-    intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
+    intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
 
     if key in intern_keys:
         return intern_string(value)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 187510576a..d2f25063aa 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -40,9 +40,7 @@ _CacheSentinel = object()
 
 
 class CacheEntry(object):
-    __slots__ = [
-        "deferred", "callbacks", "invalidated"
-    ]
+    __slots__ = ["deferred", "callbacks", "invalidated"]
 
     def __init__(self, deferred, callbacks):
         self.deferred = deferred
@@ -73,7 +71,9 @@ class Cache(object):
         self._pending_deferred_cache = cache_type()
 
         self.cache = LruCache(
-            max_size=max_entries, keylen=keylen, cache_type=cache_type,
+            max_size=max_entries,
+            keylen=keylen,
+            cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
             evicted_callback=self._on_evicted,
         )
@@ -133,10 +133,7 @@ class Cache(object):
     def set(self, key, value, callback=None):
         callbacks = [callback] if callback else []
         self.check_thread()
-        entry = CacheEntry(
-            deferred=value,
-            callbacks=callbacks,
-        )
+        entry = CacheEntry(deferred=value, callbacks=callbacks)
 
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
@@ -191,9 +188,7 @@ class Cache(object):
     def invalidate_many(self, key):
         self.check_thread()
         if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )
+            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
@@ -244,29 +239,25 @@ class _CacheDescriptorBase(object):
             raise Exception(
                 "Not enough explicit positional arguments to key off for %r: "
                 "got %i args, but wanted %i. (@cached cannot key off *args or "
-                "**kwargs)"
-                % (orig.__name__, len(all_args), num_args)
+                "**kwargs)" % (orig.__name__, len(all_args), num_args)
             )
 
         self.num_args = num_args
 
         # list of the names of the args used as the cache key
-        self.arg_names = all_args[1:num_args + 1]
+        self.arg_names = all_args[1 : num_args + 1]
 
         # self.arg_defaults is a map of arg name to its default value for each
         # argument that has a default value
         if arg_spec.defaults:
-            self.arg_defaults = dict(zip(
-                all_args[-len(arg_spec.defaults):],
-                arg_spec.defaults
-            ))
+            self.arg_defaults = dict(
+                zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults)
+            )
         else:
             self.arg_defaults = {}
 
         if "cache_context" in self.arg_names:
-            raise Exception(
-                "cache_context arg cannot be included among the cache keys"
-            )
+            raise Exception("cache_context arg cannot be included among the cache keys")
 
         self.add_cache_context = cache_context
 
@@ -304,12 +295,24 @@ class CacheDescriptor(_CacheDescriptorBase):
             ``cache_context``) to use as cache keys. Defaults to all named
             args of the function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
-                 inlineCallbacks=False, cache_context=False, iterable=False):
+
+    def __init__(
+        self,
+        orig,
+        max_entries=1000,
+        num_args=None,
+        tree=False,
+        inlineCallbacks=False,
+        cache_context=False,
+        iterable=False,
+    ):
 
         super(CacheDescriptor, self).__init__(
-            orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
-            cache_context=cache_context)
+            orig,
+            num_args=num_args,
+            inlineCallbacks=inlineCallbacks,
+            cache_context=cache_context,
+        )
 
         max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
 
@@ -356,7 +359,9 @@ class CacheDescriptor(_CacheDescriptorBase):
                     return args[0]
                 else:
                     return self.arg_defaults[nm]
+
         else:
+
             def get_cache_key(args, kwargs):
                 return tuple(get_cache_key_gen(args, kwargs))
 
@@ -383,8 +388,7 @@ class CacheDescriptor(_CacheDescriptorBase):
 
             except KeyError:
                 ret = defer.maybeDeferred(
-                    logcontext.preserve_fn(self.function_to_call),
-                    obj, *args, **kwargs
+                    logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs
                 )
 
                 def onErr(f):
@@ -437,8 +441,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
     results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=None,
-                 inlineCallbacks=False):
+    def __init__(
+        self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
+    ):
         """
         Args:
             orig (function)
@@ -451,7 +456,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 be wrapped by defer.inlineCallbacks
         """
         super(CacheListDescriptor, self).__init__(
-            orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks
+        )
 
         self.list_name = list_name
 
@@ -463,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
-                % (self.list_name, cached_method_name,)
+                % (self.list_name, cached_method_name)
             )
 
     def __get__(self, obj, objtype=None):
@@ -494,8 +500,10 @@ class CacheListDescriptor(_CacheDescriptorBase):
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
+
                 def arg_to_cache_key(arg):
                     return arg
+
             else:
                 keylist = list(keyargs)
 
@@ -505,8 +513,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
             for arg in list_args:
                 try:
-                    res = cache.get(arg_to_cache_key(arg),
-                                    callback=invalidate_callback)
+                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
@@ -554,18 +561,15 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = list(missing)
 
-                cached_defers.append(defer.maybeDeferred(
-                    logcontext.preserve_fn(self.function_to_call),
-                    **args_to_call
-                ).addCallbacks(complete_all, errback))
+                cached_defers.append(
+                    defer.maybeDeferred(
+                        logcontext.preserve_fn(self.function_to_call), **args_to_call
+                    ).addCallbacks(complete_all, errback)
+                )
 
             if cached_defers:
-                d = defer.gatherResults(
-                    cached_defers,
-                    consumeErrors=True,
-                ).addCallbacks(
-                    lambda _: results,
-                    unwrapFirstError
+                d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
+                    lambda _: results, unwrapFirstError
                 )
                 return logcontext.make_deferred_yieldable(d)
             else:
@@ -586,8 +590,9 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
-           iterable=False):
+def cached(
+    max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -598,8 +603,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
-                          cache_context=False, iterable=False):
+def cachedInlineCallbacks(
+    max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
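
For reference, a usage sketch of the decorators whose signatures are reformatted above; the store and method names are hypothetical.

    class ExampleStore(object):
        @cached(max_entries=5000)
        def get_user(self, user_id):
            # body elided; returns a Deferred, cached per user_id
            ...

        @cachedInlineCallbacks(num_args=2, tree=True)
        def get_event_membership(self, room_id, event_id):
            # body elided; keyed on the (room_id, event_id) tuple
            ...
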
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 6c0b5a4094..6834e6f3ae 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
             there.
         value (dict): The full or partial dict value
     """
+
     def __len__(self):
         return len(self.value)
 
@@ -84,13 +85,15 @@ class DictionaryCache(object):
             self.metrics.inc_hits()
 
             if dict_keys is None:
-                return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
+                return DictionaryEntry(
+                    entry.full, entry.known_absent, dict(entry.value)
+                )
             else:
-                return DictionaryEntry(entry.full, entry.known_absent, {
-                    k: entry.value[k]
-                    for k in dict_keys
-                    if k in entry.value
-                })
+                return DictionaryEntry(
+                    entry.full,
+                    entry.known_absent,
+                    {k: entry.value[k] for k in dict_keys if k in entry.value},
+                )
 
         self.metrics.inc_misses()
         return DictionaryEntry(False, set(), {})
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index f369780277..cddf1ed515 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -28,8 +28,15 @@ SENTINEL = object()
 
 
 class ExpiringCache(object):
-    def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
-                 reset_expiry_on_get=False, iterable=False):
+    def __init__(
+        self,
+        cache_name,
+        clock,
+        max_len=0,
+        expiry_ms=0,
+        reset_expiry_on_get=False,
+        iterable=False,
+    ):
         """
         Args:
             cache_name (str): Name of this cache, used for logging.
@@ -67,8 +74,7 @@ class ExpiringCache(object):
 
         def f():
             return run_as_background_process(
-                "prune_cache_%s" % self._cache_name,
-                self._prune_cache,
+                "prune_cache_%s" % self._cache_name, self._prune_cache
             )
 
         self._clock.looping_call(f, self._expiry_ms / 2)
@@ -153,7 +159,9 @@ class ExpiringCache(object):
 
         logger.debug(
             "[%s] _prune_cache before: %d, after len: %d",
-            self._cache_name, begin_length, len(self)
+            self._cache_name,
+            begin_length,
+            len(self),
         )
 
     def __len__(self):
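
A usage sketch matching the constructor reformatted above (the cache name and values are hypothetical). Note from the earlier hunk that setting expiry_ms also schedules a background prune at half that interval.

    cache = ExpiringCache(
        cache_name="remote_server_keys",  # used for logging and metrics
        clock=clock,                      # a synapse.util.Clock
        max_len=1000,                     # evict oldest entries beyond this
        expiry_ms=30 * 60 * 1000,         # entries expire after 30 minutes
        reset_expiry_on_get=True,         # reads refresh an entry's lifetime
    )
    cache["server1"] = keys               # hypothetical value
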
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index b684f24e7b..1536cb64f3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,8 +49,15 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
-                 evicted_callback=None):
+
+    def __init__(
+        self,
+        max_size,
+        keylen=1,
+        cache_type=dict,
+        size_callback=None,
+        evicted_callback=None,
+    ):
         """
         Args:
             max_size (int):
@@ -93,9 +100,12 @@ class LruCache(object):
 
         cached_cache_len = [0]
         if size_callback is not None:
+
             def cache_len():
                 return cached_cache_len[0]
+
         else:
+
             def cache_len():
                 return len(cache)
 
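The cache_len indirection above exists because the two modes measure different things; a brief sketch (values assumed):

    # Without size_callback, "size" is the number of entries:
    plain = LruCache(max_size=100)

    # With size_callback, "size" is the sum of callback(value) over entries,
    # e.g. counting dict items when values are dicts (cf. iterable=True in
    # descriptors.py above):
    weighted = LruCache(max_size=10000, size_callback=len)
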
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index afb03b2e1b..cbe54d45dd 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -35,12 +35,10 @@ class ResponseCache(object):
         self.pending_result_cache = {}  # Requests that haven't finished yet.
 
         self.clock = hs.get_clock()
-        self.timeout_sec = timeout_ms / 1000.
+        self.timeout_sec = timeout_ms / 1000.0
 
         self._name = name
-        self._metrics = register_cache(
-            "response_cache", name, self
-        )
+        self._metrics = register_cache("response_cache", name, self)
 
     def size(self):
         return len(self.pending_result_cache)
@@ -100,8 +98,7 @@ class ResponseCache(object):
         def remove(r):
             if self.timeout_sec:
                 self.clock.call_later(
-                    self.timeout_sec,
-                    self.pending_result_cache.pop, key, None,
+                    self.timeout_sec, self.pending_result_cache.pop, key, None
                 )
             else:
                 self.pending_result_cache.pop(key, None)
@@ -140,21 +137,22 @@ class ResponseCache(object):
 
             *args: positional parameters to pass to the callback, if it is used
 
-            **kwargs: named paramters to pass to the callback, if it is used
+            **kwargs: named parameters to pass to the callback, if it is used
 
         Returns:
             twisted.internet.defer.Deferred: yieldable result
         """
         result = self.get(key)
         if not result:
-            logger.info("[%s]: no cached result for [%s], calculating new one",
-                        self._name, key)
+            logger.info(
+                "[%s]: no cached result for [%s], calculating new one", self._name, key
+            )
             d = run_in_background(callback, *args, **kwargs)
             result = self.set(key, d)
         elif not isinstance(result, defer.Deferred) or result.called:
-            logger.info("[%s]: using completed cached result for [%s]",
-                        self._name, key)
+            logger.info("[%s]: using completed cached result for [%s]", self._name, key)
         else:
-            logger.info("[%s]: using incomplete cached result for [%s]",
-                        self._name, key)
+            logger.info(
+                "[%s]: using incomplete cached result for [%s]", self._name, key
+            )
         return make_deferred_yieldable(result)
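
A usage sketch for wrap(), per the docstring above; _do_expensive_work is a hypothetical callback. Concurrent calls with the same key share a single pending computation.

    @defer.inlineCallbacks
    def on_request(self, room_id):
        result = yield self.response_cache.wrap(
            room_id,                  # cache key
            self._do_expensive_work,  # only invoked on a cache miss
            room_id,
        )
        defer.returnValue(result)
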
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 625aedc940..235f64049c 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -77,9 +77,8 @@ class StreamChangeCache(object):
 
         if stream_pos >= self._earliest_known_stream_pos:
             changed_entities = {
-                self._cache[k] for k in self._cache.islice(
-                    start=self._cache.bisect_right(stream_pos),
-                )
+                self._cache[k]
+                for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
             }
 
             result = changed_entities.intersection(entities)
@@ -114,8 +113,10 @@ class StreamChangeCache(object):
         assert type(stream_pos) is int
 
         if stream_pos >= self._earliest_known_stream_pos:
-            return [self._cache[k] for k in self._cache.islice(
-                start=self._cache.bisect_right(stream_pos))]
+            return [
+                self._cache[k]
+                for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
+            ]
         else:
             return None
 
@@ -136,7 +137,7 @@ class StreamChangeCache(object):
             while len(self._cache) > self._max_size:
                 k, r = self._cache.popitem(0)
                 self._earliest_known_stream_pos = max(
-                    k, self._earliest_known_stream_pos,
+                    k, self._earliest_known_stream_pos
                 )
                 self._entity_to_key.pop(r, None)
 
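A worked sketch of the lookup logic above (example data assumed; the sorted-dict API matches what the hunks use, here taken from sortedcontainers): bisect_right(stream_pos) finds the first key strictly after stream_pos, and islice walks every entry from there.

    from sortedcontainers import SortedDict

    cache = SortedDict({1: "@a:hs", 3: "@b:hs", 5: "@c:hs"})
    changed = [cache[k] for k in cache.islice(start=cache.bisect_right(2))]
    assert changed == ["@b:hs", "@c:hs"]
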
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index dd4c9e6067..9a72218d85 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -9,6 +9,7 @@ class TreeCache(object):
     efficiently.
     Keys must be tuples.
     """
+
     def __init__(self):
         self.size = 0
         self.root = {}
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 5ba1862506..2af8ca43b1 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -155,6 +155,7 @@ class TTLCache(object):
 @attr.s(frozen=True, slots=True)
 class _CacheEntry(object):
     """TTLCache entry"""
+
     # expiry_time is the first attribute, so that entries are sorted by expiry.
     expiry_time = attr.ib()
     key = attr.ib()
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index e14c8bdfda..5a79db821c 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -51,9 +51,7 @@ class Distributor(object):
         if name in self.signals:
             raise KeyError("%r already has a signal named %s" % (self, name))
 
-        self.signals[name] = Signal(
-            name,
-        )
+        self.signals[name] = Signal(name)
 
         if name in self.pre_registration:
             signal = self.signals[name]
@@ -78,11 +76,7 @@ class Distributor(object):
         if name not in self.signals:
             raise KeyError("%r does not have a signal named %s" % (self, name))
 
-        run_as_background_process(
-            name,
-            self.signals[name].fire,
-            *args, **kwargs
-        )
+        run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
 
 
 class Signal(object):
@@ -118,22 +112,23 @@ class Signal(object):
             def eb(failure):
                 logger.warning(
                     "%s signal observer %s failed: %r",
-                    self.name, observer, failure,
+                    self.name,
+                    observer,
+                    failure,
                     exc_info=(
                         failure.type,
                         failure.value,
-                        failure.getTracebackObject()))
+                        failure.getTracebackObject(),
+                    ),
+                )
 
             return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
 
-        deferreds = [
-            run_in_background(do, o)
-            for o in self.observers
-        ]
+        deferreds = [run_in_background(do, o) for o in self.observers]
 
-        return make_deferred_yieldable(defer.gatherResults(
-            deferreds, consumeErrors=True,
-        ))
+        return make_deferred_yieldable(
+            defer.gatherResults(deferreds, consumeErrors=True)
+        )
 
     def __repr__(self):
         return "<Signal name=%r>" % (self.name,)
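
A usage sketch (the signal name is hypothetical, and observe() is assumed from the part of Signal not shown in this hunk). fire() runs the observers as a background process, and a failing observer is logged by the eb errback above rather than propagated.

    distributor = Distributor()
    distributor.declare("user_registered")
    distributor.observe("user_registered", on_user_registered)  # any callable
    distributor.fire("user_registered", "@alice:example.com")
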
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 014edea971..635b897d6c 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -60,11 +60,10 @@ def _handle_frozendict(obj):
         # fishing the protected dict out of the object is a bit nasty,
         # but we don't really want the overhead of copying the dict.
         return obj._dict
-    raise TypeError('Object of type %s is not JSON serializable' %
-                    obj.__class__.__name__)
+    raise TypeError(
+        "Object of type %s is not JSON serializable" % obj.__class__.__name__
+    )
 
 
 # A JSONEncoder which is capable of encoding frozendics without barfing
-frozendict_json_encoder = json.JSONEncoder(
-    default=_handle_frozendict,
-)
+frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)
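
A usage sketch for the encoder above; the isinstance check for frozendict sits just outside this hunk, and frozendict here is the package synapse uses for immutable event contents.

    from frozendict import frozendict

    frozendict_json_encoder.encode({"content": frozendict(body="hello")})
    # -> '{"content": {"body": "hello"}}'
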
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 2d7ddc1cbe..1a20c596bf 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource):
 
         logger.info("Attaching %s to path %s", res, full_path)
         last_resource = root_resource
-        for path_seg in full_path.split(b'/')[1:-1]:
+        for path_seg in full_path.split(b"/")[1:-1]:
             if path_seg not in last_resource.listNames():
                 # resource doesn't exist, so make a "dummy resource"
                 child_resource = NoResource()
@@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
 
         # ===========================
         # now attach the actual desired resource
-        last_path_seg = full_path.split(b'/')[-1]
+        last_path_seg = full_path.split(b"/")[-1]
 
         # if there is already a resource here, thieve its children and
         # replace it
@@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource):
             # to be replaced with the desired resource.
             existing_dummy_resource = resource_mappings[res_id]
             for child_name in existing_dummy_resource.listNames():
-                child_res_id = _resource_id(
-                    existing_dummy_resource, child_name
-                )
+                child_res_id = _resource_id(existing_dummy_resource, child_name)
                 child_resource = resource_mappings[child_res_id]
                 # steal the children
                 res.putChild(child_name, child_resource)
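
A usage sketch per the logic above (resource names hypothetical): desired_tree maps full byte paths to resources, and intermediate path segments are filled in with dummy NoResource nodes.

    from twisted.web.resource import NoResource, Resource

    root = Resource()
    create_resource_tree(
        {b"/_matrix/client/r0": client_resource},  # hypothetical resource
        root,
    )
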
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index d668e5a6b8..6dce03dd3a 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -70,7 +70,8 @@ class JsonEncodedObject(object):
             dict
         """
         d = {
-            k: _encode(v) for (k, v) in self.__dict__.items()
+            k: _encode(v)
+            for (k, v) in self.__dict__.items()
             if k in self.valid_keys and k not in self.internal_keys
         }
         d.update(self.unrecognized_keys)
@@ -78,7 +79,8 @@ class JsonEncodedObject(object):
 
     def get_internal_dict(self):
         d = {
-            k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
+            k: _encode(v, internal=True)
+            for (k, v) in self.__dict__.items()
             if k in self.valid_keys
         }
         d.update(self.unrecognized_keys)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 311b49e18a..6b0d2deea0 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -42,6 +42,8 @@ try:
 
     def get_thread_resource_usage():
         return resource.getrusage(RUSAGE_THREAD)
+
+
 except Exception:
     # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
     # won't track resource usage by returning None.
@@ -64,8 +66,11 @@ class ContextResourceUsage(object):
     """
 
     __slots__ = [
-        "ru_stime", "ru_utime",
-        "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
+        "ru_stime",
+        "ru_utime",
+        "db_txn_count",
+        "db_txn_duration_sec",
+        "db_sched_duration_sec",
         "evt_db_fetch_count",
     ]
 
@@ -91,8 +96,8 @@ class ContextResourceUsage(object):
         return ContextResourceUsage(copy_from=self)
 
     def reset(self):
-        self.ru_stime = 0.
-        self.ru_utime = 0.
+        self.ru_stime = 0.0
+        self.ru_utime = 0.0
         self.db_txn_count = 0
 
         self.db_txn_duration_sec = 0
@@ -100,15 +105,18 @@ class ContextResourceUsage(object):
         self.evt_db_fetch_count = 0
 
     def __repr__(self):
-        return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
-                "db_txn_count='%r', db_txn_duration_sec='%r', "
-                "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
-                    self.ru_stime,
-                    self.ru_utime,
-                    self.db_txn_count,
-                    self.db_txn_duration_sec,
-                    self.db_sched_duration_sec,
-                    self.evt_db_fetch_count,)
+        return (
+            "<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
+            "db_txn_count='%r', db_txn_duration_sec='%r', "
+            "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>"
+        ) % (
+            self.ru_stime,
+            self.ru_utime,
+            self.db_txn_count,
+            self.db_txn_duration_sec,
+            self.db_sched_duration_sec,
+            self.evt_db_fetch_count,
+        )
 
     def __iadd__(self, other):
         """Add another ContextResourceUsage's stats to this one's.
@@ -159,11 +167,15 @@ class LoggingContext(object):
     """
 
     __slots__ = [
-        "previous_context", "name", "parent_context",
+        "previous_context",
+        "name",
+        "parent_context",
         "_resource_usage",
         "usage_start",
-        "main_thread", "alive",
-        "request", "tag",
+        "main_thread",
+        "alive",
+        "request",
+        "tag",
     ]
 
     thread_local = threading.local()
@@ -196,6 +208,7 @@ class LoggingContext(object):
 
         def __nonzero__(self):
             return False
+
         __bool__ = __nonzero__  # python3
 
     sentinel = Sentinel()
@@ -226,6 +239,8 @@ class LoggingContext(object):
             self.request = request
 
     def __str__(self):
+        if self.request:
+            return str(self.request)
         return "%s@%x" % (self.name, id(self))
 
     @classmethod
@@ -259,7 +274,8 @@ class LoggingContext(object):
         if self.previous_context != old_context:
             logger.warn(
                 "Expected previous context %r, found %r",
-                self.previous_context, old_context
+                self.previous_context,
+                old_context,
             )
         self.alive = True
 
@@ -274,20 +290,17 @@ class LoggingContext(object):
         current = self.set_current_context(self.previous_context)
         if current is not self:
             if current is self.sentinel:
-                logger.warn("Expected logging context %s has been lost", self)
+                logger.warning("Expected logging context %s was lost", self)
             else:
-                logger.warn(
-                    "Current logging context %s is not expected context %s",
-                    current,
-                    self
+                logger.warning(
+                    "Expected logging context %s but found %s", self, current
                 )
         self.previous_context = None
         self.alive = False
 
         # if we have a parent, pass our CPU usage stats on
-        if (
-            self.parent_context is not None
-            and hasattr(self.parent_context, '_resource_usage')
+        if self.parent_context is not None and hasattr(
+            self.parent_context, "_resource_usage"
         ):
             self.parent_context._resource_usage += self._resource_usage
 
@@ -320,15 +333,12 @@ class LoggingContext(object):
 
         # When we stop, let's record the cpu used since we started
         if not self.usage_start:
-            logger.warning(
-                "Called stop on logcontext %s without calling start", self,
-            )
+            logger.warning("Called stop on logcontext %s without calling start", self)
             return
 
-        usage_end = get_thread_resource_usage()
-
-        self._resource_usage.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime
-        self._resource_usage.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime
+        utime_delta, stime_delta = self._get_cputime()
+        self._resource_usage.ru_utime += utime_delta
+        self._resource_usage.ru_stime += stime_delta
 
         self.usage_start = None
 
@@ -346,13 +356,44 @@ class LoggingContext(object):
         # can include resource usage so far.
         is_main_thread = threading.current_thread() is self.main_thread
         if self.alive and self.usage_start and is_main_thread:
-            current = get_thread_resource_usage()
-            res.ru_utime += current.ru_utime - self.usage_start.ru_utime
-            res.ru_stime += current.ru_stime - self.usage_start.ru_stime
+            utime_delta, stime_delta = self._get_cputime()
+            res.ru_utime += utime_delta
+            res.ru_stime += stime_delta
 
         return res
 
+    def _get_cputime(self):
+        """Get the cpu usage time so far
+
+        Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
+        """
+        current = get_thread_resource_usage()
+
+        utime_delta = current.ru_utime - self.usage_start.ru_utime
+        stime_delta = current.ru_stime - self.usage_start.ru_stime
+
+        # sanity check
+        if utime_delta < 0:
+            logger.error(
+                "utime went backwards! %f < %f",
+                current.ru_utime,
+                self.usage_start.ru_utime,
+            )
+            utime_delta = 0
+
+        if stime_delta < 0:
+            logger.error(
+                "stime went backwards! %f < %f",
+                current.ru_stime,
+                self.usage_start.ru_stime,
+            )
+            stime_delta = 0
+
+        return utime_delta, stime_delta
+
     def add_database_transaction(self, duration_sec):
+        if duration_sec < 0:
+            raise ValueError("DB txn time can only be non-negative")
         self._resource_usage.db_txn_count += 1
         self._resource_usage.db_txn_duration_sec += duration_sec
 
@@ -363,6 +404,8 @@ class LoggingContext(object):
             sched_sec (float): number of seconds it took us to get a
                 connection
         """
+        if sched_sec < 0:
+            raise ValueError("DB scheduling time can only be non-negative")
         self._resource_usage.db_sched_duration_sec += sched_sec
 
     def record_event_fetch(self, event_count):
@@ -381,6 +424,7 @@ class LoggingContextFilter(logging.Filter):
         **defaults: Default values to avoid formatters complaining about
             missing fields
     """
+
     def __init__(self, **defaults):
         self.defaults = defaults
 
@@ -416,34 +460,30 @@ class PreserveLoggingContext(object):
 
     def __enter__(self):
         """Captures the current logging context"""
-        self.current_context = LoggingContext.set_current_context(
-            self.new_context
-        )
+        self.current_context = LoggingContext.set_current_context(self.new_context)
 
         if self.current_context:
             self.has_parent = self.current_context.previous_context is not None
             if not self.current_context.alive:
-                logger.debug(
-                    "Entering dead context: %s",
-                    self.current_context,
-                )
+                logger.debug("Entering dead context: %s", self.current_context)
 
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
         context = LoggingContext.set_current_context(self.current_context)
 
         if context != self.new_context:
-            logger.warn(
-                "Unexpected logging context: %s is not %s",
-                context, self.new_context,
-            )
+            if context is LoggingContext.sentinel:
+                logger.warning("Expected logging context %s was lost", self.new_context)
+            else:
+                logger.warning(
+                    "Expected logging context %s but found %s",
+                    self.new_context,
+                    context,
+                )
 
         if self.current_context is not LoggingContext.sentinel:
             if not self.current_context.alive:
-                logger.debug(
-                    "Restoring dead context: %s",
-                    self.current_context,
-                )
+                logger.debug("Restoring dead context: %s", self.current_context)
 
 
 def nested_logging_context(suffix, parent_context=None):
@@ -470,15 +510,16 @@ def nested_logging_context(suffix, parent_context=None):
     if parent_context is None:
         parent_context = LoggingContext.current_context()
     return LoggingContext(
-        parent_context=parent_context,
-        request=parent_context.request + "-" + suffix,
+        parent_context=parent_context, request=parent_context.request + "-" + suffix
     )
 
 
 def preserve_fn(f):
     """Function decorator which wraps the function with run_in_background"""
+
     def g(*args, **kwargs):
         return run_in_background(f, *args, **kwargs)
+
     return g
 
 
@@ -498,7 +539,7 @@ def run_in_background(f, *args, **kwargs):
     current = LoggingContext.current_context()
     try:
         res = f(*args, **kwargs)
-    except:   # noqa: E722
+    except:  # noqa: E722
         # the assumption here is that the caller doesn't want to be disturbed
         # by synchronous exceptions, so let's turn them into Failures.
         return defer.fail()
@@ -635,6 +676,4 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         with LoggingContext(parent_context=logcontext):
             return f(*args, **kwargs)
 
-    return make_deferred_yieldable(
-        threads.deferToThreadPool(reactor, threadpool, g)
-    )
+    return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
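
A usage sketch for nested_logging_context, per the hunk above (the request id is hypothetical, and the request= constructor argument is assumed from the earlier hunks): the child context's request id is the parent's plus a suffix, and its resource usage is folded back into the parent on exit.

    with LoggingContext(request="GET-42"):
        with nested_logging_context("subtask"):
            ...  # log lines here carry the request id "GET-42-subtask"
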
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index a46bc47ce3..fbf570c756 100644
--- a/synapse/util/logformatter.py
+++ b/synapse/util/logformatter.py
@@ -29,6 +29,7 @@ class LogFormatter(logging.Formatter):
     (Normally only stack frames between the point the exception was raised and
     where it was caught are logged).
     """
+
     def __init__(self, *args, **kwargs):
         super(LogFormatter, self).__init__(*args, **kwargs)
 
@@ -40,7 +41,7 @@ class LogFormatter(logging.Formatter):
         # check that we actually have an f_back attribute to work around
         # https://twistedmatrix.com/trac/ticket/9305
 
-        if tb and hasattr(tb.tb_frame, 'f_back'):
+        if tb and hasattr(tb.tb_frame, "f_back"):
             sio.write("Capture point (most recent call last):\n")
             traceback.print_stack(tb.tb_frame.f_back, None, sio)
 
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index ef31458226..7df0fa6087 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -44,7 +44,7 @@ def _log_debug_as_f(f, msg, msg_args):
             lineno=lineno,
             msg=msg,
             args=msg_args,
-            exc_info=None
+            exc_info=None,
         )
 
         logger.handle(record)
@@ -70,20 +70,11 @@ def log_function(f):
                     r = r[:50] + "..."
                 return r
 
-            func_args = [
-                "%s=%s" % (k, format(v)) for k, v in bound_args.items()
-            ]
+            func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()]
 
-            msg_args = {
-                "func_name": func_name,
-                "args": ", ".join(func_args)
-            }
+            msg_args = {"func_name": func_name, "args": ", ".join(func_args)}
 
-            _log_debug_as_f(
-                f,
-                "Invoked '%(func_name)s' with args: %(args)s",
-                msg_args
-            )
+            _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args)
 
         return f(*args, **kwargs)
 
@@ -103,19 +94,13 @@ def time_function(f):
         start = time.clock()
 
         try:
-            _log_debug_as_f(
-                f,
-                "[FUNC START] {%s-%d}",
-                (func_name, id),
-            )
+            _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id))
 
             r = f(*args, **kwargs)
         finally:
             end = time.clock()
             _log_debug_as_f(
-                f,
-                "[FUNC END] {%s-%d} %.3f sec",
-                (func_name, id, end - start,),
+                f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start)
             )
 
         return r
@@ -137,9 +122,8 @@ def trace_function(f):
         s = inspect.currentframe().f_back
 
         to_print = [
-            "\t%s:%s %s. Args: args=%s, kwargs=%s" % (
-                pathname, linenum, func_name, args, kwargs
-            )
+            "\t%s:%s %s. Args: args=%s, kwargs=%s"
+            % (pathname, linenum, func_name, args, kwargs)
         ]
         while s:
             if True or s.f_globals["__name__"].startswith("synapse"):
@@ -147,9 +131,7 @@ def trace_function(f):
                 args_string = inspect.formatargvalues(*inspect.getargvalues(s))
 
                 to_print.append(
-                    "\t%s:%d %s. Args: %s" % (
-                        filename, lineno, function, args_string
-                    )
+                    "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string)
                 )
 
             s = s.f_back
@@ -163,7 +145,7 @@ def trace_function(f):
             lineno=lineno,
             msg=msg,
             args=None,
-            exc_info=None
+            exc_info=None,
         )
 
         logger.handle(record)
@@ -182,13 +164,13 @@ def get_previous_frames():
             filename, lineno, function, _, _ = inspect.getframeinfo(s)
             args_string = inspect.formatargvalues(*inspect.getargvalues(s))
 
-            to_return.append("{{  %s:%d %s - Args: %s }}" % (
-                filename, lineno, function, args_string
-            ))
+            to_return.append(
+                "{{  %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string)
+            )
 
         s = s.f_back
 
-    return ", ". join(to_return)
+    return ", ".join(to_return)
 
 
 def get_previous_frame(ignore=[]):
@@ -201,7 +183,10 @@ def get_previous_frame(ignore=[]):
                 args_string = inspect.formatargvalues(*inspect.getargvalues(s))
 
                 return "{{  %s:%d %s - Args: %s }}" % (
-                    filename, lineno, function, args_string
+                    filename,
+                    lineno,
+                    function,
+                    args_string,
                 )
 
         s = s.f_back
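
A usage sketch for the decorators above: log_function records entry into the wrapped function with its formatted arguments, and time_function additionally logs the elapsed time, both at DEBUG level.

    @log_function
    def send_transaction(destination, txn):
        ...

    @time_function
    def expensive_step(data):
        ...
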
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 628a2962d9..631654f297 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -74,27 +74,25 @@ def manhole(username, password, globals):
         twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
     """
     if not isinstance(password, bytes):
-        password = password.encode('ascii')
+        password = password.encode("ascii")
 
-    checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
-        **{username: password}
-    )
+    checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
 
     rlm = manhole_ssh.TerminalRealm()
     rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
-        SynapseManhole,
-        dict(globals, __name__="__console__")
+        SynapseManhole, dict(globals, __name__="__console__")
     )
 
     factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
-    factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY)
-    factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY)
+    factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY)
+    factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY)
 
     return factory
 
 
 class SynapseManhole(ColoredManhole):
     """Overrides connectionMade to create our own ManholeInterpreter"""
+
     def connectionMade(self):
         super(SynapseManhole, self).connectionMade()
 
@@ -127,7 +125,7 @@ class SynapseManholeInterpreter(ManholeInterpreter):
                 value = SyntaxError(msg, (filename, lineno, offset, line))
                 sys.last_value = value
         lines = traceback.format_exception_only(type, value)
-        self.write(''.join(lines))
+        self.write("".join(lines))
 
     def showtraceback(self):
         """Display the exception that just occurred.
@@ -140,6 +138,6 @@ class SynapseManholeInterpreter(ManholeInterpreter):
         try:
             # We remove the first stack item because it is our own code.
             lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
-            self.write(''.join(lines))
+            self.write("".join(lines))
         finally:
             last_tb = ei = None
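
A usage sketch per the docstring above (the port and credentials are hypothetical): manhole() returns a protocol factory for listenTCP, giving an SSH REPL whose namespace is the passed globals.

    from twisted.internet import reactor

    factory = manhole(
        username="matrix",
        password="rabbithole",  # hypothetical credentials
        globals={"hs": hs},     # objects exposed in the REPL
    )
    reactor.listenTCP(9000, factory, interface="127.0.0.1")
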
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4b4ac5f6c7..01284d3cf8 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -30,25 +30,31 @@ block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"])
 block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"])
 
 block_ru_utime = Counter(
-    "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"])
+    "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]
+)
 
 block_ru_stime = Counter(
-    "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"])
+    "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]
+)
 
 block_db_txn_count = Counter(
-    "synapse_util_metrics_block_db_txn_count", "", ["block_name"])
+    "synapse_util_metrics_block_db_txn_count", "", ["block_name"]
+)
 
 # seconds spent waiting for db txns, excluding scheduling time, in this block
 block_db_txn_duration = Counter(
-    "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"])
+    "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]
+)
 
 # seconds spent waiting for a db connection, in this block
 block_db_sched_duration = Counter(
-    "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
+    "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
+)
 
 # Tracks the number of blocks currently active
 in_flight = InFlightGauge(
-    "synapse_util_metrics_block_in_flight", "",
+    "synapse_util_metrics_block_in_flight",
+    "",
     labels=["block_name"],
     sub_metrics=["real_time_max", "real_time_sum"],
 )
@@ -62,13 +68,18 @@ def measure_func(name):
             with Measure(self.clock, name):
                 r = yield func(self, *args, **kwargs)
             defer.returnValue(r)
+
         return measured_func
+
     return wrapper
 
 
 class Measure(object):
     __slots__ = [
-        "clock", "name", "start_context", "start",
+        "clock",
+        "name",
+        "start_context",
+        "start",
         "created_context",
         "start_usage",
     ]
@@ -108,7 +119,9 @@ class Measure(object):
         if context != self.start_context:
             logger.warn(
                 "Context has unexpectedly changed from '%s' to '%s'. (%r)",
-                self.start_context, context, self.name
+                self.start_context,
+                context,
+                self.name,
             )
             return
 
@@ -126,8 +139,7 @@ class Measure(object):
             block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
         except ValueError:
             logger.warn(
-                "Failed to save metrics! OLD: %r, NEW: %r",
-                self.start_usage, current
+                "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current
             )
 
         if self.created_context:
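
A usage sketch (block names hypothetical): Measure times the block and, from the logcontext usage deltas, feeds the block_* Prometheus counters declared above; measure_func is the decorator form, used on methods of an object with a self.clock.

    with Measure(self.clock, "process_event"):
        handle(event)

    @measure_func("persist_events")
    @defer.inlineCallbacks
    def persist_events(self, events):
        ...
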
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 4288312b8a..522acd5aa8 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -28,15 +28,13 @@ def load_module(provider):
     """
     # We need to import the module, and then pick the class out of
     # that, so we split based on the last dot.
-    module, clz = provider['module'].rsplit(".", 1)
+    module, clz = provider["module"].rsplit(".", 1)
     module = importlib.import_module(module)
     provider_class = getattr(module, clz)
 
     try:
         provider_config = provider_class.parse_config(provider["config"])
     except Exception as e:
-        raise ConfigError(
-            "Failed to parse config for %r: %r" % (provider['module'], e)
-        )
+        raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
 
     return provider_class, provider_config
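
A usage sketch: `provider` mirrors the dict shape read above, i.e. a dotted "module" path whose last component is the class, plus a "config" block handed to that class's parse_config().

    provider = {
        "module": "my_package.MyAuthProvider",  # hypothetical provider
        "config": {"enabled": True},
    }
    provider_class, provider_config = load_module(provider)
    # callers then instantiate provider_class with provider_config as needed
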
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
index a6c30e5265..c8bcbe297a 100644
--- a/synapse/util/msisdn.py
+++ b/synapse/util/msisdn.py
@@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number):
         phoneNumber = phonenumbers.parse(number, country)
     except phonenumbers.NumberParseException:
         raise SynapseError(400, "Unable to parse phone number")
-    return phonenumbers.format_number(
-        phoneNumber, phonenumbers.PhoneNumberFormat.E164
-    )[1:]
+    return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[
+        1:
+    ]
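
A worked example of the slice above: E.164 formatting produces "+<country><number>", and [1:] drops the leading "+" to give a bare MSISDN (the number is from the fictional Ofcom range).

    phone_number_to_msisdn("GB", "07700 900123")
    # E.164 gives "+447700900123"; the [1:] slice returns "447700900123"
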
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index b146d137f4..06defa8199 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -56,11 +56,7 @@ class FederationRateLimiter(object):
             _PerHostRatelimiter
         """
         return self.ratelimiters.setdefault(
-            host,
-            _PerHostRatelimiter(
-                clock=self.clock,
-                config=self._config,
-            )
+            host, _PerHostRatelimiter(clock=self.clock, config=self._config)
         ).ratelimit()
 
 
@@ -112,8 +108,7 @@ class _PerHostRatelimiter(object):
 
         # remove any entries from request_times which aren't within the window
         self.request_times[:] = [
-            r for r in self.request_times
-            if time_now - r < self.window_size
+            r for r in self.request_times if time_now - r < self.window_size
         ]
 
         # reject the request if we already have too many queued up (either
@@ -121,9 +116,7 @@ class _PerHostRatelimiter(object):
         queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
         if queue_size > self.reject_limit:
             raise LimitExceededError(
-                retry_after_ms=int(
-                    self.window_size / self.sleep_limit
-                ),
+                retry_after_ms=int(self.window_size / self.sleep_limit)
             )
 
         self.request_times.append(time_now)
@@ -143,22 +136,18 @@ class _PerHostRatelimiter(object):
 
         logger.debug(
             "Ratelimit [%s]: len(self.request_times)=%d",
-            id(request_id), len(self.request_times),
+            id(request_id),
+            len(self.request_times),
         )
 
         if len(self.request_times) > self.sleep_limit:
-            logger.debug(
-                "Ratelimiter: sleeping request for %f sec", self.sleep_sec,
-            )
+            logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
             ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
 
             self.sleeping_requests.add(request_id)
 
             def on_wait_finished(_):
-                logger.debug(
-                    "Ratelimit [%s]: Finished sleeping",
-                    id(request_id),
-                )
+                logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
                 self.sleeping_requests.discard(request_id)
                 queue_defer = queue_request()
                 return queue_defer
@@ -168,10 +157,7 @@ class _PerHostRatelimiter(object):
             ret_defer = queue_request()
 
         def on_start(r):
-            logger.debug(
-                "Ratelimit [%s]: Processing req",
-                id(request_id),
-            )
+            logger.debug("Ratelimit [%s]: Processing req", id(request_id))
             self.current_processing.add(request_id)
             return r
 
@@ -193,10 +179,7 @@ class _PerHostRatelimiter(object):
         return make_deferred_yieldable(ret_defer)
 
     def _on_exit(self, request_id):
-        logger.debug(
-            "Ratelimit [%s]: Processed req",
-            id(request_id),
-        )
+        logger.debug("Ratelimit [%s]: Processed req", id(request_id))
         self.current_processing.discard(request_id)
         try:
             # start processing the next item on the queue.
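(Taken together, these hunks reformat a classic sliding-window limiter: prune request_times to the window, reject when too much is queued, otherwise sleep or queue the request. A standalone sketch of just the prune-and-reject core, with illustrative limits; the real class also manages sleeping and queued deferreds:)

class SlidingWindowSketch(object):
    """Hypothetical, minimal version of _PerHostRatelimiter's window check."""

    def __init__(self, window_size_ms=1000, reject_limit=50):
        self.window_size = window_size_ms
        self.reject_limit = reject_limit
        self.request_times = []

    def check(self, time_now_ms):
        # keep only entries that are still within the window
        self.request_times[:] = [
            r for r in self.request_times if time_now_ms - r < self.window_size
        ]
        if len(self.request_times) >= self.reject_limit:
            # stands in for LimitExceededError in the real code
            raise RuntimeError("too many requests in window")
        self.request_times.append(time_now_ms)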
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 26cce7d197..1a77456498 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -46,8 +46,7 @@ class NotRetryingDestination(Exception):
 
 
 @defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, ignore_backoff=False,
-                      **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
     """For a given destination check if we have previously failed to
     send a request there and are waiting before retrying the destination.
     If we are not ready to retry the destination, this will raise a
@@ -60,8 +59,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
         clock (synapse.util.clock): timing source
         store (synapse.storage.transactions.TransactionStore): datastore
         ignore_backoff (bool): true to ignore the historical backoff data and
-            try the request anyway. We will still update the next
-            retry_interval on success/failure.
+            try the request anyway. We will still reset the retry_interval on success.
 
     Example usage:
 
@@ -75,13 +73,12 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
     """
     retry_last_ts, retry_interval = (0, 0)
 
-    retry_timings = yield store.get_destination_retry_timings(
-        destination
-    )
+    retry_timings = yield store.get_destination_retry_timings(destination)
 
     if retry_timings:
         retry_last_ts, retry_interval = (
-            retry_timings["retry_last_ts"], retry_timings["retry_interval"]
+            retry_timings["retry_last_ts"],
+            retry_timings["retry_interval"],
         )
 
         now = int(clock.time_msec())
@@ -93,22 +90,36 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
                 destination=destination,
             )
 
+    # if we are ignoring the backoff data, we should also not increment the backoff
+    # when we get another failure - otherwise a server can very quickly reach the
+    # maximum backoff even though it might only have been down briefly
+    backoff_on_failure = not ignore_backoff
+
     defer.returnValue(
         RetryDestinationLimiter(
             destination,
             clock,
             store,
             retry_interval,
+            backoff_on_failure=backoff_on_failure,
             **kwargs
         )
     )
 
 
 class RetryDestinationLimiter(object):
-    def __init__(self, destination, clock, store, retry_interval,
-                 min_retry_interval=10 * 60 * 1000,
-                 max_retry_interval=24 * 60 * 60 * 1000,
-                 multiplier_retry_interval=5, backoff_on_404=False):
+    def __init__(
+        self,
+        destination,
+        clock,
+        store,
+        retry_interval,
+        min_retry_interval=10 * 60 * 1000,
+        max_retry_interval=24 * 60 * 60 * 1000,
+        multiplier_retry_interval=5,
+        backoff_on_404=False,
+        backoff_on_failure=True,
+    ):
         """Marks the destination as "down" if an exception is thrown in the
         context, except for CodeMessageException with code < 500.
 
@@ -128,6 +139,9 @@ class RetryDestinationLimiter(object):
             multiplier_retry_interval (int): The multiplier to use to increase
                 the retry interval after a failed request.
             backoff_on_404 (bool): Back off if we get a 404
+
+            backoff_on_failure (bool): set to False if we should not increase the
+                retry interval on a failure.
         """
         self.clock = clock
         self.store = store
@@ -138,6 +152,7 @@ class RetryDestinationLimiter(object):
         self.max_retry_interval = max_retry_interval
         self.multiplier_retry_interval = multiplier_retry_interval
         self.backoff_on_404 = backoff_on_404
+        self.backoff_on_failure = backoff_on_failure
 
     def __enter__(self):
         pass
@@ -173,10 +188,13 @@ class RetryDestinationLimiter(object):
             if not self.retry_interval:
                 return
 
-            logger.debug("Connection to %s was successful; clearing backoff",
-                         self.destination)
+            logger.debug(
+                "Connection to %s was successful; clearing backoff", self.destination
+            )
             retry_last_ts = 0
             self.retry_interval = 0
+        elif not self.backoff_on_failure:
+            return
         else:
             # We couldn't connect.
             if self.retry_interval:
@@ -190,7 +208,10 @@ class RetryDestinationLimiter(object):
 
             logger.info(
                 "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
-                self.destination, exc_type, exc_val, self.retry_interval
+                self.destination,
+                exc_type,
+                exc_val,
+                self.retry_interval,
             )
             retry_last_ts = int(self.clock.time_msec())
 
@@ -201,9 +222,7 @@ class RetryDestinationLimiter(object):
                     self.destination, retry_last_ts, self.retry_interval
                 )
             except Exception:
-                logger.exception(
-                    "Failed to store destination_retry_timings",
-                )
+                logger.exception("Failed to store destination_retry_timings")
 
         # we deliberately do this in the background.
         synapse.util.logcontext.run_in_background(store_retry_timings)
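(The substantive change in this file is backoff_on_failure: when a caller passes ignore_backoff=True, a failure no longer multiplies the retry interval, so a briefly-down server cannot be driven straight to the maximum backoff. The growth schedule itself is unchanged; a rough sketch of how the interval climbs under the default parameters shown in the signature above:)

# Sketch of the default backoff progression (values from the signature above).
min_retry_interval = 10 * 60 * 1000        # 10 minutes, in ms
max_retry_interval = 24 * 60 * 60 * 1000   # 24 hours, in ms
multiplier_retry_interval = 5

interval = 0
for attempt in range(5):
    if interval:
        interval = min(interval * multiplier_retry_interval, max_retry_interval)
    else:
        interval = min_retry_interval
    print(attempt, interval)
# -> 10m, 50m, 4h10m, 20h50m, then capped at 24h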
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 69dffd8244..982c6d81ca 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -20,9 +20,7 @@ import six
 from six import PY2, PY3
 from six.moves import range
 
-_string_with_symbols = (
-    string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
-)
+_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
 
 # random_string and random_string_with_symbols are used for a range of things,
 # some cryptographically important, some less so. We use SystemRandom to make sure
@@ -31,13 +29,11 @@ rand = random.SystemRandom()
 
 
 def random_string(length):
-    return ''.join(rand.choice(string.ascii_letters) for _ in range(length))
+    return "".join(rand.choice(string.ascii_letters) for _ in range(length))
 
 
 def random_string_with_symbols(length):
-    return ''.join(
-        rand.choice(_string_with_symbols) for _ in range(length)
-    )
+    return "".join(rand.choice(_string_with_symbols) for _ in range(length))
 
 
 def is_ascii(s):
@@ -45,7 +41,7 @@ def is_ascii(s):
     if PY3:
         if isinstance(s, bytes):
             try:
-                s.decode('ascii').encode('ascii')
+                s.decode("ascii").encode("ascii")
             except UnicodeDecodeError:
                 return False
             except UnicodeEncodeError:
@@ -104,12 +100,12 @@ def exception_to_unicode(e):
     # and instead look at what is in the args member.
 
     if len(e.args) == 0:
-        return u""
+        return ""
     elif len(e.args) > 1:
         return six.text_type(repr(e.args))
 
     msg = e.args[0]
     if isinstance(msg, bytes):
-        return msg.decode('utf-8', errors='replace')
+        return msg.decode("utf-8", errors="replace")
     else:
         return msg
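(The comment preserved above is the key point in this file: rand is a random.SystemRandom, which draws from the OS entropy pool rather than the default Mersenne Twister, so the same helper is safe for token generation. A quick illustration of the distinction; the names here are our own, not from the source:)

import random
import string

system_rand = random.SystemRandom()  # backed by os.urandom; not seedable

def make_token(length=24):
    # suitable for secrets, unlike choices drawn from the module-level PRNG
    return "".join(system_rand.choice(string.ascii_letters) for _ in range(length))

print(make_token())  # output varies on every run and cannot be reproduced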
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 75efa0117b..3ec1dfb0c2 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -35,11 +35,13 @@ def check_3pid_allowed(hs, medium, address):
         for constraint in hs.config.allowed_local_3pids:
             logger.debug(
                 "Checking 3PID %s (%s) against %s (%s)",
-                address, medium, constraint['pattern'], constraint['medium'],
+                address,
+                medium,
+                constraint["pattern"],
+                constraint["medium"],
             )
-            if (
-                medium == constraint['medium'] and
-                re.match(constraint['pattern'], address)
+            if medium == constraint["medium"] and re.match(
+                constraint["pattern"], address
             ):
                 return True
     else:
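(For reference, each entry in allowed_local_3pids pairs a medium with a regex pattern, and a 3PID is allowed as soon as one entry matches on both. A hedged sketch of that check against a hypothetical config entry:)

import re

# hypothetical config entry; real entries live in hs.config.allowed_local_3pids
constraint = {"medium": "email", "pattern": r".*@example\.com"}

medium, address = "email", "alice@example.com"
if medium == constraint["medium"] and re.match(constraint["pattern"], address):
    print("3PID allowed")  # matches both the medium and the pattern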
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 3baba3225a..a4d9a462f7 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -23,44 +23,53 @@ logger = logging.getLogger(__name__)
 
 def get_version_string(module):
     try:
-        null = open(os.devnull, 'w')
+        null = open(os.devnull, "w")
         cwd = os.path.dirname(os.path.abspath(module.__file__))
         try:
-            git_branch = subprocess.check_output(
-                ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
-                stderr=null,
-                cwd=cwd,
-            ).strip().decode('ascii')
+            git_branch = (
+                subprocess.check_output(
+                    ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd
+                )
+                .strip()
+                .decode("ascii")
+            )
             git_branch = "b=" + git_branch
         except subprocess.CalledProcessError:
             git_branch = ""
 
         try:
-            git_tag = subprocess.check_output(
-                ['git', 'describe', '--exact-match'],
-                stderr=null,
-                cwd=cwd,
-            ).strip().decode('ascii')
+            git_tag = (
+                subprocess.check_output(
+                    ["git", "describe", "--exact-match"], stderr=null, cwd=cwd
+                )
+                .strip()
+                .decode("ascii")
+            )
             git_tag = "t=" + git_tag
         except subprocess.CalledProcessError:
             git_tag = ""
 
         try:
-            git_commit = subprocess.check_output(
-                ['git', 'rev-parse', '--short', 'HEAD'],
-                stderr=null,
-                cwd=cwd,
-            ).strip().decode('ascii')
+            git_commit = (
+                subprocess.check_output(
+                    ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd
+                )
+                .strip()
+                .decode("ascii")
+            )
         except subprocess.CalledProcessError:
             git_commit = ""
 
         try:
             dirty_string = "-this_is_a_dirty_checkout"
-            is_dirty = subprocess.check_output(
-                ['git', 'describe', '--dirty=' + dirty_string],
-                stderr=null,
-                cwd=cwd,
-            ).strip().decode('ascii').endswith(dirty_string)
+            is_dirty = (
+                subprocess.check_output(
+                    ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd
+                )
+                .strip()
+                .decode("ascii")
+                .endswith(dirty_string)
+            )
 
             git_dirty = "dirty" if is_dirty else ""
         except subprocess.CalledProcessError:
@@ -68,16 +77,10 @@ def get_version_string(module):
 
         if git_branch or git_tag or git_commit or git_dirty:
             git_version = ",".join(
-                s for s in
-                (git_branch, git_tag, git_commit, git_dirty,)
-                if s
+                s for s in (git_branch, git_tag, git_commit, git_dirty) if s
             )
 
-            return (
-                "%s (%s)" % (
-                    module.__version__, git_version,
-                )
-            )
+            return "%s (%s)" % (module.__version__, git_version)
     except Exception as e:
         logger.info("Failed to check for git repository: %s", e)
 
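(Each of the four git invocations above contributes one optional, comma-separated component. Assembling the final string from precomputed parts looks like the sketch below; all values are illustrative, not from the source:)

# Illustrative values; the real function shells out to git for each one.
git_branch = "b=develop"
git_tag = ""            # empty: HEAD was not exactly on a tag
git_commit = "4a5bc3e"  # hypothetical short hash
git_dirty = "dirty"

git_version = ",".join(s for s in (git_branch, git_tag, git_commit, git_dirty) if s)
print("%s (%s)" % ("0.99.2", git_version))
# -> "0.99.2 (b=develop,4a5bc3e,dirty)"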
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 7a9e45aca9..9bf6a44f75 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -69,9 +69,7 @@ class WheelTimer(object):
 
         # Add empty entries between the end of the current list and when we want
         # to insert. This ensures there are no gaps.
-        self.entries.extend(
-            _Entry(key) for key in range(last_key, then_key + 1)
-        )
+        self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))
 
         self.entries[-1].queue.append(obj)
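(The extend above pads the wheel with empty buckets up to the target key, so entries[-1] is guaranteed to be the bucket the object belongs in. A simplified sketch of that gap-filling; the bucket keys and the _Entry shape are reconstructed for illustration, not copied from the source:)

class _EntrySketch(object):
    def __init__(self, end_key):
        self.end_key = end_key
        self.queue = []

# buckets exist up to key 10; an object should expire in bucket 14
entries = [_EntrySketch(k) for k in range(11)]
last_key, then_key = entries[-1].end_key + 1, 14

# pad with empty buckets so there are no gaps, then append to the last one
entries.extend(_EntrySketch(key) for key in range(last_key, then_key + 1))
entries[-1].queue.append("obj")
assert entries[-1].end_key == 14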