summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py8
-rw-r--r--synapse/util/async_helpers.py4
-rw-r--r--synapse/util/caches/__init__.py17
-rw-r--r--synapse/util/caches/descriptors.py94
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--synapse/util/metrics.py2
-rw-r--r--synapse/util/retryutils.py16
-rw-r--r--synapse/util/versionstring.py23
8 files changed, 113 insertions, 53 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index f506b2a695..7856353002 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -49,7 +49,7 @@ class Clock(object):
         with context.PreserveLoggingContext():
             self._reactor.callLater(seconds, d.callback, seconds)
             res = yield d
-        defer.returnValue(res)
+        return res
 
     def time(self):
         """Returns the current system time in seconds since epoch."""
@@ -59,7 +59,7 @@ class Clock(object):
         """Returns the current system time in miliseconds since epoch."""
         return int(self.time() * 1000)
 
-    def looping_call(self, f, msec):
+    def looping_call(self, f, msec, *args, **kwargs):
         """Call a function repeatedly.
 
         Waits `msec` initially before calling `f` for the first time.
@@ -70,8 +70,10 @@ class Clock(object):
         Args:
             f(function): The function to call repeatedly.
             msec(float): How long to wait between calls in milliseconds.
+            *args: Postional arguments to pass to function.
+            **kwargs: Key arguments to pass to function.
         """
-        call = task.LoopingCall(f)
+        call = task.LoopingCall(f, *args, **kwargs)
         call.clock = self._reactor
         d = call.start(msec / 1000.0, now=False)
         d.addErrback(log_failure, "Looping call died", consumeErrors=False)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 58a6b8764f..f1c46836b1 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -366,7 +366,7 @@ class ReadWriteLock(object):
                 new_defer.callback(None)
                 self.key_to_current_readers.get(key, set()).discard(new_defer)
 
-        defer.returnValue(_ctx_manager())
+        return _ctx_manager()
 
     @defer.inlineCallbacks
     def write(self, key):
@@ -396,7 +396,7 @@ class ReadWriteLock(object):
                 if self.key_to_current_writer[key] == new_defer:
                     self.key_to_current_writer.pop(key)
 
-        defer.returnValue(_ctx_manager())
+        return _ctx_manager()
 
 
 def _cancelled_to_timed_out_error(value, timeout):
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8271229015..b50e3503f0 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -51,7 +52,19 @@ response_cache_evicted = Gauge(
 response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
 
 
-def register_cache(cache_type, cache_name, cache):
+def register_cache(cache_type, cache_name, cache, collect_callback=None):
+    """Register a cache object for metric collection.
+
+    Args:
+        cache_type (str):
+        cache_name (str): name of the cache
+        cache (object): cache itself
+        collect_callback (callable|None): if not None, a function which is called during
+            metric collection to update additional metrics.
+
+    Returns:
+        CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+    """
 
     # Check if the metric is already registered. Unregister it, if so.
     # This usually happens during tests, as at runtime these caches are
@@ -90,6 +103,8 @@ def register_cache(cache_type, cache_name, cache):
                     cache_hits.labels(cache_name).set(self.hits)
                     cache_evicted.labels(cache_name).set(self.evicted_size)
                     cache_total.labels(cache_name).set(self.hits + self.misses)
+                if collect_callback:
+                    collect_callback()
             except Exception as e:
                 logger.warn("Error calculating metrics for %s: %s", cache_name, e)
                 raise
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 675db2f448..43f66ec4be 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,8 +19,9 @@ import logging
 import threading
 from collections import namedtuple
 
-import six
-from six import itervalues, string_types
+from six import itervalues
+
+from prometheus_client import Gauge
 
 from twisted.internet import defer
 
@@ -30,13 +31,18 @@ from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.stringutils import to_ascii
 
 from . import register_cache
 
 logger = logging.getLogger(__name__)
 
 
+cache_pending_metric = Gauge(
+    "synapse_util_caches_cache_pending",
+    "Number of lookups currently pending for this cache",
+    ["name"],
+)
+
 _CacheSentinel = object()
 
 
@@ -82,11 +88,19 @@ class Cache(object):
         self.name = name
         self.keylen = keylen
         self.thread = None
-        self.metrics = register_cache("cache", name, self.cache)
+        self.metrics = register_cache(
+            "cache",
+            name,
+            self.cache,
+            collect_callback=self._metrics_collection_callback,
+        )
 
     def _on_evicted(self, evicted_count):
         self.metrics.inc_evictions(evicted_count)
 
+    def _metrics_collection_callback(self):
+        cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
+
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
@@ -108,7 +122,7 @@ class Cache(object):
             update_metrics (bool): whether to update the cache hit rate metrics
 
         Returns:
-            Either a Deferred or the raw result
+            Either an ObservableDeferred or the raw result
         """
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -132,9 +146,14 @@ class Cache(object):
             return default
 
     def set(self, key, value, callback=None):
+        if not isinstance(value, defer.Deferred):
+            raise TypeError("not a Deferred")
+
         callbacks = [callback] if callback else []
         self.check_thread()
-        entry = CacheEntry(deferred=value, callbacks=callbacks)
+        observable = ObservableDeferred(value, consumeErrors=True)
+        observer = defer.maybeDeferred(observable.observe)
+        entry = CacheEntry(deferred=observable, callbacks=callbacks)
 
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
@@ -142,20 +161,31 @@ class Cache(object):
 
         self._pending_deferred_cache[key] = entry
 
-        def shuffle(result):
+        def compare_and_pop():
+            """Check if our entry is still the one in _pending_deferred_cache, and
+            if so, pop it.
+
+            Returns true if the entries matched.
+            """
             existing_entry = self._pending_deferred_cache.pop(key, None)
             if existing_entry is entry:
+                return True
+
+            # oops, the _pending_deferred_cache has been updated since
+            # we started our query, so we are out of date.
+            #
+            # Better put back whatever we took out. (We do it this way
+            # round, rather than peeking into the _pending_deferred_cache
+            # and then removing on a match, to make the common case faster)
+            if existing_entry is not None:
+                self._pending_deferred_cache[key] = existing_entry
+
+            return False
+
+        def cb(result):
+            if compare_and_pop():
                 self.cache.set(key, result, entry.callbacks)
             else:
-                # oops, the _pending_deferred_cache has been updated since
-                # we started our query, so we are out of date.
-                #
-                # Better put back whatever we took out. (We do it this way
-                # round, rather than peeking into the _pending_deferred_cache
-                # and then removing on a match, to make the common case faster)
-                if existing_entry is not None:
-                    self._pending_deferred_cache[key] = existing_entry
-
                 # we're not going to put this entry into the cache, so need
                 # to make sure that the invalidation callbacks are called.
                 # That was probably done when _pending_deferred_cache was
@@ -163,9 +193,16 @@ class Cache(object):
                 # `invalidate` being previously called, in which case it may
                 # not have been. Either way, let's double-check now.
                 entry.invalidate()
-            return result
 
-        entry.deferred.addCallback(shuffle)
+        def eb(_fail):
+            compare_and_pop()
+            entry.invalidate()
+
+        # once the deferred completes, we can move the entry from the
+        # _pending_deferred_cache to the real cache.
+        #
+        observer.addCallbacks(cb, eb)
+        return observable
 
     def prefill(self, key, value, callback=None):
         callbacks = [callback] if callback else []
@@ -289,7 +326,7 @@ class CacheDescriptor(_CacheDescriptorBase):
         def foo(self, key, cache_context):
             r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
-            defer.returnValue(r1 + r2)
+            return r1 + r2
 
     Args:
         num_args (int): number of positional arguments (excluding ``self`` and
@@ -398,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
 
                 ret.addErrback(onErr)
 
-                # If our cache_key is a string on py2, try to convert to ascii
-                # to save a bit of space in large caches. Py3 does this
-                # internally automatically.
-                if six.PY2 and isinstance(cache_key, string_types):
-                    cache_key = to_ascii(cache_key)
-
-                result_d = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, result_d, callback=invalidate_callback)
+                result_d = cache.set(cache_key, ret, callback=invalidate_callback)
                 observer = result_d.observe()
 
-            if isinstance(observer, defer.Deferred):
-                return make_deferred_yieldable(observer)
-            else:
-                return observer
+            return make_deferred_yieldable(observer)
 
         if self.num_args == 1:
             wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -527,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     missing.add(arg)
 
             if missing:
-                # we need an observable deferred for each entry in the list,
+                # we need a deferred for each entry in the list,
                 # which we put in the cache. Each deferred resolves with the
                 # relevant result for that key.
                 deferreds_map = {}
@@ -535,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     deferred = defer.Deferred()
                     deferreds_map[arg] = deferred
                     key = arg_to_cache_key(arg)
-                    observable = ObservableDeferred(deferred)
-                    cache.set(key, observable, callback=invalidate_callback)
+                    cache.set(key, deferred, callback=invalidate_callback)
 
                 def complete_all(res):
                     # the wrapped function has completed. It returns a
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index d6908e169d..82d3eefe0e 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -121,7 +121,7 @@ class ResponseCache(object):
             @defer.inlineCallbacks
             def handle_request(request):
                 # etc
-                defer.returnValue(result)
+                return result
 
             result = yield response_cache.wrap(
                 key,
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c30b6de19c..0910930c21 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -67,7 +67,7 @@ def measure_func(name):
         def measured_func(self, *args, **kwargs):
             with Measure(self.clock, name):
                 r = yield func(self, *args, **kwargs)
-            defer.returnValue(r)
+            return r
 
         return measured_func
 
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index d8d0ceae51..0862b5ca5a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -95,15 +95,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
     # maximum backoff even though it might only have been down briefly
     backoff_on_failure = not ignore_backoff
 
-    defer.returnValue(
-        RetryDestinationLimiter(
-            destination,
-            clock,
-            store,
-            retry_interval,
-            backoff_on_failure=backoff_on_failure,
-            **kwargs
-        )
+    return RetryDestinationLimiter(
+        destination,
+        clock,
+        store,
+        retry_interval,
+        backoff_on_failure=backoff_on_failure,
+        **kwargs
     )
 
 
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index a4d9a462f7..fa404b9d75 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -22,6 +22,23 @@ logger = logging.getLogger(__name__)
 
 
 def get_version_string(module):
+    """Given a module calculate a git-aware version string for it.
+
+    If called on a module not in a git checkout will return `__verison__`.
+
+    Args:
+        module (module)
+
+    Returns:
+        str
+    """
+
+    cached_version = getattr(module, "_synapse_version_string_cache", None)
+    if cached_version:
+        return cached_version
+
+    version_string = module.__version__
+
     try:
         null = open(os.devnull, "w")
         cwd = os.path.dirname(os.path.abspath(module.__file__))
@@ -80,8 +97,10 @@ def get_version_string(module):
                 s for s in (git_branch, git_tag, git_commit, git_dirty) if s
             )
 
-            return "%s (%s)" % (module.__version__, git_version)
+            version_string = "%s (%s)" % (module.__version__, git_version)
     except Exception as e:
         logger.info("Failed to check for git repository: %s", e)
 
-    return module.__version__
+    module._synapse_version_string_cache = version_string
+
+    return version_string