diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 756d8ffa32..fc11e26623 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.errors import SynapseError
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -21,12 +20,9 @@ from twisted.internet import defer, reactor, task
import time
import logging
-logger = logging.getLogger(__name__)
-
+from itertools import islice
-class DeferredTimedOutError(SynapseError):
- def __init__(self):
- super(DeferredTimedOutError, self).__init__(504, "Timed out")
+logger = logging.getLogger(__name__)
def unwrapFirstError(failure):
@@ -86,52 +82,18 @@ class Clock(object):
if not ignore_errs:
raise
- def time_bound_deferred(self, given_deferred, time_out):
- if given_deferred.called:
- return given_deferred
-
- ret_deferred = defer.Deferred()
-
- def timed_out_fn():
- e = DeferredTimedOutError()
-
- try:
- ret_deferred.errback(e)
- except Exception:
- pass
-
- try:
- given_deferred.cancel()
- except Exception:
- pass
- timer = None
+def batch_iter(iterable, size):
+ """batch an iterable up into tuples with a maximum size
- def cancel(res):
- try:
- self.cancel_call_later(timer)
- except Exception:
- pass
- return res
+ Args:
+ iterable (iterable): the iterable to slice
+ size (int): the maximum batch size
- ret_deferred.addBoth(cancel)
-
- def success(res):
- try:
- ret_deferred.callback(res)
- except Exception:
- pass
-
- return res
-
- def err(res):
- try:
- ret_deferred.errback(res)
- except Exception:
- pass
-
- given_deferred.addCallbacks(callback=success, errback=err)
-
- timer = self.call_later(time_out, timed_out_fn)
-
- return ret_deferred
+ Returns:
+ an iterator over the chunks
+ """
+ # make sure we can deal with iterables like lists too
+ sourceiter = iter(iterable)
+ # call islice until it returns an empty tuple
+ return iter(lambda: tuple(islice(sourceiter, size)), ())
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 0729bb2863..9dd4e6b5bc 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -15,9 +15,11 @@
from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
from .logcontext import (
- PreserveLoggingContext, make_deferred_yieldable, preserve_fn
+ PreserveLoggingContext, make_deferred_yieldable, run_in_background
)
from synapse.util import logcontext, unwrapFirstError
@@ -25,6 +27,8 @@ from contextlib import contextmanager
import logging
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -156,13 +160,13 @@ def concurrently_execute(func, args, limit):
def _concurrently_execute_inner():
try:
while True:
- yield func(it.next())
+ yield func(next(it))
except StopIteration:
pass
return logcontext.make_deferred_yieldable(defer.gatherResults([
- preserve_fn(_concurrently_execute_inner)()
- for _ in xrange(limit)
+ run_in_background(_concurrently_execute_inner)
+ for _ in range(limit)
], consumeErrors=True)).addErrback(unwrapFirstError)
@@ -392,3 +396,68 @@ class ReadWriteLock(object):
self.key_to_current_writer.pop(key)
defer.returnValue(_ctx_manager())
+
+
+class DeferredTimeoutError(Exception):
+ """
+ This error is raised by default when a L{Deferred} times out.
+ """
+
+
+def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
+ """
+ Add a timeout to a deferred by scheduling it to be cancelled after
+ timeout seconds.
+
+ This is essentially a backport of deferred.addTimeout, which was introduced
+ in twisted 16.5.
+
+ If the deferred gets timed out, it errbacks with a DeferredTimeoutError,
+ unless a cancelable function was passed to its initialization or unless
+ a different on_timeout_cancel callable is provided.
+
+ Args:
+ deferred (defer.Deferred): deferred to be timed out
+ timeout (Number): seconds to time out after
+
+ on_timeout_cancel (callable): A callable which is called immediately
+ after the deferred times out, and not if this deferred is
+ otherwise cancelled before the timeout.
+
+ It takes an arbitrary value, which is the value of the deferred at
+ that exact point in time (probably a CancelledError Failure), and
+ the timeout.
+
+ The default callable (if none is provided) will translate a
+ CancelledError Failure into a DeferredTimeoutError.
+ """
+ timed_out = [False]
+
+ def time_it_out():
+ timed_out[0] = True
+ deferred.cancel()
+
+ delayed_call = reactor.callLater(timeout, time_it_out)
+
+ def convert_cancelled(value):
+ if timed_out[0]:
+ to_call = on_timeout_cancel or _cancelled_to_timed_out_error
+ return to_call(value, timeout)
+ return value
+
+ deferred.addBoth(convert_cancelled)
+
+ def cancel_timeout(result):
+ # stop the pending call to cancel the deferred if it's been fired
+ if delayed_call.active():
+ delayed_call.cancel()
+ return result
+
+ deferred.addBoth(cancel_timeout)
+
+
+def _cancelled_to_timed_out_error(value, timeout):
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise DeferredTimeoutError(timeout, "Deferred")
+ return value
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 4adae96681..183faf75a1 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -13,28 +13,77 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.metrics
+from prometheus_client.core import Gauge, REGISTRY, GaugeMetricFamily
+
import os
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
+from six.moves import intern
+import six
-metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
caches_by_name = {}
-# cache_counter = metrics.register_cache(
-# "cache",
-# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-# labels=["name"],
-# )
-
-
-def register_cache(name, cache):
- caches_by_name[name] = cache
- return metrics.register_cache(
- "cache",
- lambda: len(cache),
- name,
- )
+collectors_by_name = {}
+
+cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
+cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
+cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
+cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
+
+response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
+response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
+response_cache_evicted = Gauge(
+ "synapse_util_caches_response_cache:evicted_size", "", ["name"]
+)
+response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+
+
+def register_cache(cache_type, cache_name, cache):
+
+ # Check if the metric is already registered. Unregister it, if so.
+ # This usually happens during tests, as at runtime these caches are
+ # effectively singletons.
+ metric_name = "cache_%s_%s" % (cache_type, cache_name)
+ if metric_name in collectors_by_name.keys():
+ REGISTRY.unregister(collectors_by_name[metric_name])
+
+ class CacheMetric(object):
+
+ hits = 0
+ misses = 0
+ evicted_size = 0
+
+ def inc_hits(self):
+ self.hits += 1
+
+ def inc_misses(self):
+ self.misses += 1
+
+ def inc_evictions(self, size=1):
+ self.evicted_size += size
+
+ def describe(self):
+ return []
+
+ def collect(self):
+ if cache_type == "response_cache":
+ response_cache_size.labels(cache_name).set(len(cache))
+ response_cache_hits.labels(cache_name).set(self.hits)
+ response_cache_evicted.labels(cache_name).set(self.evicted_size)
+ response_cache_total.labels(cache_name).set(self.hits + self.misses)
+ else:
+ cache_size.labels(cache_name).set(len(cache))
+ cache_hits.labels(cache_name).set(self.hits)
+ cache_evicted.labels(cache_name).set(self.evicted_size)
+ cache_total.labels(cache_name).set(self.hits + self.misses)
+
+ yield GaugeMetricFamily("__unused", "")
+
+ metric = CacheMetric()
+ REGISTRY.register(metric)
+ caches_by_name[cache_name] = cache
+ collectors_by_name[metric_name] = metric
+ return metric
KNOWN_KEYS = {
@@ -66,7 +115,9 @@ def intern_string(string):
return None
try:
- string = string.encode("ascii")
+ if six.PY2:
+ string = string.encode("ascii")
+
return intern(string)
except UnicodeEncodeError:
return string
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index bf3a66eae4..8a9dcb2fc2 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,12 +40,11 @@ _CacheSentinel = object()
class CacheEntry(object):
__slots__ = [
- "deferred", "sequence", "callbacks", "invalidated"
+ "deferred", "callbacks", "invalidated"
]
- def __init__(self, deferred, sequence, callbacks):
+ def __init__(self, deferred, callbacks):
self.deferred = deferred
- self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
@@ -62,7 +62,6 @@ class Cache(object):
"max_entries",
"name",
"keylen",
- "sequence",
"thread",
"metrics",
"_pending_deferred_cache",
@@ -80,9 +79,8 @@ class Cache(object):
self.name = name
self.keylen = keylen
- self.sequence = 0
self.thread = None
- self.metrics = register_cache(name, self.cache)
+ self.metrics = register_cache("cache", name, self.cache)
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
@@ -113,11 +111,10 @@ class Cache(object):
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
- if val.sequence == self.sequence:
- val.callbacks.update(callbacks)
- if update_metrics:
- self.metrics.inc_hits()
- return val.deferred
+ val.callbacks.update(callbacks)
+ if update_metrics:
+ self.metrics.inc_hits()
+ return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
@@ -137,12 +134,9 @@ class Cache(object):
self.check_thread()
entry = CacheEntry(
deferred=value,
- sequence=self.sequence,
callbacks=callbacks,
)
- entry.callbacks.update(callbacks)
-
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
@@ -150,13 +144,25 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
def shuffle(result):
- if self.sequence == entry.sequence:
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- self.cache.set(key, result, entry.callbacks)
- else:
- entry.invalidate()
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry is entry:
+ self.cache.set(key, result, entry.callbacks)
else:
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ # we're not going to put this entry into the cache, so need
+ # to make sure that the invalidation callbacks are called.
+ # That was probably done when _pending_deferred_cache was
+ # updated, but it's possible that `set` was called without
+ # `invalidate` being previously called, in which case it may
+ # not have been. Either way, let's double-check now.
entry.invalidate()
return result
@@ -168,25 +174,29 @@ class Cache(object):
def invalidate(self, key):
self.check_thread()
+ self.cache.pop(key, None)
- # Increment the sequence number so that any SELECT statements that
- # raced with the INSERT don't update the cache (SYN-369)
- self.sequence += 1
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, which will (a) stop it being returned
+ # for future queries and (b) stop it being persisted as a proper entry
+ # in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
+
+ # run the invalidation callbacks now, rather than waiting for the
+ # deferred to resolve.
if entry:
entry.invalidate()
- self.cache.pop(key, None)
-
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
- self.sequence += 1
self.cache.del_multi(key)
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
@@ -194,8 +204,10 @@ class Cache(object):
def invalidate_all(self):
self.check_thread()
- self.sequence += 1
self.cache.clear()
+ for entry in self._pending_deferred_cache.itervalues():
+ entry.invalidate()
+ self._pending_deferred_cache.clear()
class _CacheDescriptorBase(object):
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d4105822b3..bdc21e348f 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -55,7 +55,7 @@ class DictionaryCache(object):
__slots__ = []
self.sentinel = Sentinel()
- self.metrics = register_cache(name, self.cache)
+ self.metrics = register_cache("dictionary", name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -132,9 +132,13 @@ class DictionaryCache(object):
self._update_or_insert(key, value, known_absent)
def _update_or_insert(self, key, value, known_absent):
- entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
+ # We pop and reinsert as we need to tell the cache the size may have
+ # changed
+
+ entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
entry.value.update(value)
entry.known_absent.update(known_absent)
+ self.cache[key] = entry
def _insert(self, key, value, known_absent):
self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 0aa103eecb..ff04c91955 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -52,12 +52,12 @@ class ExpiringCache(object):
self._cache = OrderedDict()
- self.metrics = register_cache(cache_name, self)
-
self.iterable = iterable
self._size_estimate = 0
+ self.metrics = register_cache("expiring", cache_name, self)
+
def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f088dd430e..1c5a982094 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -154,14 +154,21 @@ class LruCache(object):
def cache_set(key, value, callbacks=[]):
node = cache.get(key, None)
if node is not None:
- if value != node.value:
+ # We sometimes store large objects, e.g. dicts, which cause
+ # the inequality check to take a long time. So let's only do
+ # the check if we have some callbacks to call.
+ if node.callbacks and value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
- if size_callback:
- cached_cache_len[0] -= size_callback(node.value)
- cached_cache_len[0] += size_callback(value)
+ # We don't bother to protect this by value != node.value as
+ # generally size_callback will be cheap compared with equality
+ # checks. (For example, taking the size of two dicts is quicker
+ # than comparing them for equality.)
+ if size_callback:
+ cached_cache_len[0] -= size_callback(node.value)
+ cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 00af539880..a8491b42d5 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,8 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from twisted.internet import defer
from synapse.util.async import ObservableDeferred
+from synapse.util.caches import register_cache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+logger = logging.getLogger(__name__)
class ResponseCache(object):
@@ -24,20 +31,69 @@ class ResponseCache(object):
used rather than trying to compute a new response.
"""
- def __init__(self, hs, timeout_ms=0):
+ def __init__(self, hs, name, timeout_ms=0):
self.pending_result_cache = {} # Requests that haven't finished yet.
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.
+ self._name = name
+ self._metrics = register_cache(
+ "response_cache", name, self
+ )
+
+ def size(self):
+ return len(self.pending_result_cache)
+
+ def __len__(self):
+ return self.size()
+
def get(self, key):
+ """Look up the given key.
+
+ Can return either a new Deferred (which also doesn't follow the synapse
+ logcontext rules), or, if the request has completed, the actual
+ result. You will probably want to make_deferred_yieldable the result.
+
+ If there is no entry for the key, returns None. It is worth noting that
+ this means there is no way to distinguish a completed result of None
+ from an absent cache entry.
+
+ Args:
+ key (hashable):
+
+ Returns:
+ twisted.internet.defer.Deferred|None|E: None if there is no entry
+ for this key; otherwise either a deferred result or the result
+ itself.
+ """
result = self.pending_result_cache.get(key)
if result is not None:
+ self._metrics.inc_hits()
return result.observe()
else:
+ self._metrics.inc_misses()
return None
def set(self, key, deferred):
+ """Set the entry for the given key to the given deferred.
+
+ *deferred* should run its callbacks in the sentinel logcontext (ie,
+ you should wrap normal synapse deferreds with
+ logcontext.run_in_background).
+
+ Can return either a new Deferred (which also doesn't follow the synapse
+ logcontext rules), or, if *deferred* was already complete, the actual
+ result. You will probably want to make_deferred_yieldable the result.
+
+ Args:
+ key (hashable):
+ deferred (twisted.internet.defer.Deferred[T):
+
+ Returns:
+ twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
+ result.
+ """
result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
@@ -53,3 +109,52 @@ class ResponseCache(object):
result.addBoth(remove)
return result.observe()
+
+ def wrap(self, key, callback, *args, **kwargs):
+ """Wrap together a *get* and *set* call, taking care of logcontexts
+
+ First looks up the key in the cache, and if it is present makes it
+ follow the synapse logcontext rules and returns it.
+
+ Otherwise, makes a call to *callback(*args, **kwargs)*, which should
+ follow the synapse logcontext rules, and adds the result to the cache.
+
+ Example usage:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+ # etc
+ defer.returnValue(result)
+
+ result = yield response_cache.wrap(
+ key,
+ handle_request,
+ request,
+ )
+
+ Args:
+ key (hashable): key to get/set in the cache
+
+ callback (callable): function to call if the key is not found in
+ the cache
+
+ *args: positional parameters to pass to the callback, if it is used
+
+ **kwargs: named paramters to pass to the callback, if it is used
+
+ Returns:
+ twisted.internet.defer.Deferred: yieldable result
+ """
+ result = self.get(key)
+ if not result:
+ logger.info("[%s]: no cached result for [%s], calculating new one",
+ self._name, key)
+ d = run_in_background(callback, *args, **kwargs)
+ result = self.set(key, d)
+ elif not isinstance(result, defer.Deferred) or result.called:
+ logger.info("[%s]: using completed cached result for [%s]",
+ self._name, key)
+ else:
+ logger.info("[%s]: using incomplete cached result for [%s]",
+ self._name, key)
+ return make_deferred_yieldable(result)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 941d873ab8..a7fe0397fa 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -38,7 +38,7 @@ class StreamChangeCache(object):
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
- self.metrics = register_cache(self.name, self._cache)
+ self.metrics = register_cache("cache", self.name, self._cache)
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 90a2608d6f..3380970e4e 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -15,9 +15,9 @@
from twisted.internet import threads, reactor
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
-import Queue
+from six.moves import queue
class BackgroundFileConsumer(object):
@@ -49,7 +49,7 @@ class BackgroundFileConsumer(object):
# Queue of slices of bytes to be written. When producer calls
# unregister a final None is sent.
- self._bytes_queue = Queue.Queue()
+ self._bytes_queue = queue.Queue()
# Deferred that is resolved when finished writing
self._finished_deferred = None
@@ -70,7 +70,9 @@ class BackgroundFileConsumer(object):
self._producer = producer
self.streaming = streaming
- self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
+ self._finished_deferred = run_in_background(
+ threads.deferToThread, self._writer
+ )
if not streaming:
self._producer.resumeProducing()
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 6322f0f55c..f497b51f4a 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -14,6 +14,7 @@
# limitations under the License.
from frozendict import frozendict
+import simplejson as json
def freeze(o):
@@ -49,3 +50,21 @@ def unfreeze(o):
pass
return o
+
+
+def _handle_frozendict(obj):
+ """Helper for EventEncoder. Makes frozendicts serializable by returning
+ the underlying dict
+ """
+ if type(obj) is frozendict:
+ # fishing the protected dict out of the object is a bit nasty,
+ # but we don't really want the overhead of copying the dict.
+ return obj._dict
+ raise TypeError('Object of type %s is not JSON serializable' %
+ obj.__class__.__name__)
+
+
+# A JSONEncoder which is capable of encoding frozendics without barfing
+frozendict_json_encoder = json.JSONEncoder(
+ default=_handle_frozendict,
+)
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 45be47159a..e9f0f292ee 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
import logging
@@ -40,12 +40,15 @@ def create_resource_tree(desired_tree, root_resource):
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
+ # twisted requires all resources to be bytes
+ full_path = full_path.encode("utf-8")
+
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
- for path_seg in full_path.split('/')[1:-1]:
+ for path_seg in full_path.split(b'/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
- child_resource = Resource()
+ child_resource = NoResource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
@@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
# ===========================
# now attach the actual desired resource
- last_path_seg = full_path.split('/')[-1]
+ last_path_seg = full_path.split(b'/')[-1]
# if there is already a resource here, thieve its children and
# replace it
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index d660ec785b..a58c723403 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -59,8 +59,8 @@ class LoggingContext(object):
__slots__ = [
"previous_context", "name", "ru_stime", "ru_utime",
- "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
- "usage_start", "usage_end",
+ "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
+ "usage_start",
"main_thread", "alive",
"request", "tag",
]
@@ -84,14 +84,15 @@ class LoggingContext(object):
def stop(self):
pass
- def add_database_transaction(self, duration_ms):
+ def add_database_transaction(self, duration_sec):
pass
- def add_database_scheduled(self, sched_ms):
+ def add_database_scheduled(self, sched_sec):
pass
def __nonzero__(self):
return False
+ __bool__ = __nonzero__ # python3
sentinel = Sentinel()
@@ -102,14 +103,16 @@ class LoggingContext(object):
self.ru_utime = 0.
self.db_txn_count = 0
- # ms spent waiting for db txns, excluding scheduling time
- self.db_txn_duration_ms = 0
+ # sec spent waiting for db txns, excluding scheduling time
+ self.db_txn_duration_sec = 0
- # ms spent waiting for db txns to be scheduled
- self.db_sched_duration_ms = 0
+ # sec spent waiting for db txns to be scheduled
+ self.db_sched_duration_sec = 0
+ # If alive has the thread resource usage when the logcontext last
+ # became active.
self.usage_start = None
- self.usage_end = None
+
self.main_thread = threading.current_thread()
self.request = None
self.tag = ""
@@ -158,12 +161,12 @@ class LoggingContext(object):
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
- None to avoid suppressing any exeptions that were thrown.
+ None to avoid suppressing any exceptions that were thrown.
"""
current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
- logger.debug("Expected logging context %s has been lost", self)
+ logger.warn("Expected logging context %s has been lost", self)
else:
logger.warn(
"Current logging context %s is not expected context %s",
@@ -184,47 +187,61 @@ class LoggingContext(object):
def start(self):
if threading.current_thread() is not self.main_thread:
+ logger.warning("Started logcontext %s on different thread", self)
return
- if self.usage_start and self.usage_end:
- self.ru_utime += self.usage_end.ru_utime - self.usage_start.ru_utime
- self.ru_stime += self.usage_end.ru_stime - self.usage_start.ru_stime
- self.usage_start = None
- self.usage_end = None
-
+ # If we haven't already started record the thread resource usage so
+ # far
if not self.usage_start:
self.usage_start = get_thread_resource_usage()
def stop(self):
if threading.current_thread() is not self.main_thread:
+ logger.warning("Stopped logcontext %s on different thread", self)
return
+ # When we stop, let's record the resource used since we started
if self.usage_start:
- self.usage_end = get_thread_resource_usage()
+ usage_end = get_thread_resource_usage()
+
+ self.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime
+ self.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime
+
+ self.usage_start = None
+ else:
+ logger.warning("Called stop on logcontext %s without calling start", self)
def get_resource_usage(self):
+ """Get CPU time used by this logcontext so far.
+
+ Returns:
+ tuple[float, float]: The user and system CPU usage in seconds
+ """
ru_utime = self.ru_utime
ru_stime = self.ru_stime
- if self.usage_start and threading.current_thread() is self.main_thread:
+ # If we are on the correct thread and we're currently running then we
+ # can include resource usage so far.
+ is_main_thread = threading.current_thread() is self.main_thread
+ if self.alive and self.usage_start and is_main_thread:
current = get_thread_resource_usage()
ru_utime += current.ru_utime - self.usage_start.ru_utime
ru_stime += current.ru_stime - self.usage_start.ru_stime
return ru_utime, ru_stime
- def add_database_transaction(self, duration_ms):
+ def add_database_transaction(self, duration_sec):
self.db_txn_count += 1
- self.db_txn_duration_ms += duration_ms
+ self.db_txn_duration_sec += duration_sec
- def add_database_scheduled(self, sched_ms):
+ def add_database_scheduled(self, sched_sec):
"""Record a use of the database pool
Args:
- sched_ms (int): number of milliseconds it took us to get a
+ sched_sec (float): number of seconds it took us to get a
connection
"""
- self.db_sched_duration_ms += sched_ms
+ self.db_sched_duration_sec += sched_sec
class LoggingContextFilter(logging.Filter):
@@ -278,7 +295,7 @@ class PreserveLoggingContext(object):
context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
- logger.debug(
+ logger.warn(
"Unexpected logging context: %s is not %s",
context, self.new_context,
)
@@ -301,31 +318,49 @@ def preserve_fn(f):
def run_in_background(f, *args, **kwargs):
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
- deferred returned by the funtion completes.
+ deferred returned by the function completes.
Useful for wrapping functions that return a deferred which you don't yield
- on.
+ on (for instance because you want to pass it to deferred.gatherResults()).
+
+ Note that if you completely discard the result, you should make sure that
+ `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+ CRITICAL error about an unhandled error will be logged without much
+ indication about where it came from.
"""
current = LoggingContext.current_context()
- res = f(*args, **kwargs)
- if isinstance(res, defer.Deferred) and not res.called:
- # The function will have reset the context before returning, so
- # we need to restore it now.
- LoggingContext.set_current_context(current)
-
- # The original context will be restored when the deferred
- # completes, but there is nothing waiting for it, so it will
- # get leaked into the reactor or some other function which
- # wasn't expecting it. We therefore need to reset the context
- # here.
- #
- # (If this feels asymmetric, consider it this way: we are
- # effectively forking a new thread of execution. We are
- # probably currently within a ``with LoggingContext()`` block,
- # which is supposed to have a single entry and exit point. But
- # by spawning off another deferred, we are effectively
- # adding a new exit point.)
- res.addBoth(_set_context_cb, LoggingContext.sentinel)
+ try:
+ res = f(*args, **kwargs)
+ except: # noqa: E722
+ # the assumption here is that the caller doesn't want to be disturbed
+ # by synchronous exceptions, so let's turn them into Failures.
+ return defer.fail()
+
+ if not isinstance(res, defer.Deferred):
+ return res
+
+ if res.called and not res.paused:
+ # The function should have maintained the logcontext, so we can
+ # optimise out the messing about
+ return res
+
+ # The function may have reset the context before returning, so
+ # we need to restore it now.
+ ctx = LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(_set_context_cb, ctx)
return res
@@ -340,11 +375,20 @@ def make_deferred_yieldable(deferred):
returning a deferred. Then, when the deferred completes, restores the
current logcontext before running callbacks/errbacks.
- (This is more-or-less the opposite operation to preserve_fn.)
+ (This is more-or-less the opposite operation to run_in_background.)
"""
- if isinstance(deferred, defer.Deferred) and not deferred.called:
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
- deferred.addBoth(_set_context_cb, prev_context)
+ if not isinstance(deferred, defer.Deferred):
+ return deferred
+
+ if deferred.called and not deferred.paused:
+ # it looks like this deferred is ready to run any callbacks we give it
+ # immediately. We may as well optimise out the logcontext faffery.
+ return deferred
+
+ # ok, we can't be sure that a yield won't block, so let's reset the
+ # logcontext, and add a callback to the deferred to restore it.
+ prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ deferred.addBoth(_set_context_cb, prev_context)
return deferred
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index cdbc4bffd7..3e42868ea9 100644
--- a/synapse/util/logformatter.py
+++ b/synapse/util/logformatter.py
@@ -14,7 +14,7 @@
# limitations under the License.
-import StringIO
+from six import StringIO
import logging
import traceback
@@ -32,7 +32,7 @@ class LogFormatter(logging.Formatter):
super(LogFormatter, self).__init__(*args, **kwargs)
def formatException(self, ei):
- sio = StringIO.StringIO()
+ sio = StringIO()
(typ, val, tb) = ei
# log the stack above the exception capture point if possible, but
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index 3a83828d25..03249c5dc8 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -96,7 +96,7 @@ def time_function(f):
id = _TIME_FUNC_ID
_TIME_FUNC_ID += 1
- start = time.clock() * 1000
+ start = time.clock()
try:
_log_debug_as_f(
@@ -107,10 +107,10 @@ def time_function(f):
r = f(*args, **kwargs)
finally:
- end = time.clock() * 1000
+ end = time.clock()
_log_debug_as_f(
f,
- "[FUNC END] {%s-%d} %f",
+ "[FUNC END] {%s-%d} %.3f sec",
(func_name, id, end - start,),
)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index e4b5687a4b..1ba7d65c7c 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,8 +15,8 @@
from twisted.internet import defer
+from prometheus_client import Counter
from synapse.util.logcontext import LoggingContext
-import synapse.metrics
from functools import wraps
import logging
@@ -24,66 +24,26 @@ import logging
logger = logging.getLogger(__name__)
+block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"])
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-# total number of times we have hit this block
-block_counter = metrics.register_counter(
- "block_count",
- labels=["block_name"],
- alternative_names=(
- # the following are all deprecated aliases for the same metric
- metrics.name_prefix + x for x in (
- "_block_timer:count",
- "_block_ru_utime:count",
- "_block_ru_stime:count",
- "_block_db_txn_count:count",
- "_block_db_txn_duration:count",
- )
- )
-)
-
-block_timer = metrics.register_counter(
- "block_time_seconds",
- labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_timer:total",
- ),
-)
-
-block_ru_utime = metrics.register_counter(
- "block_ru_utime_seconds", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_ru_utime:total",
- ),
-)
-
-block_ru_stime = metrics.register_counter(
- "block_ru_stime_seconds", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_ru_stime:total",
- ),
-)
-
-block_db_txn_count = metrics.register_counter(
- "block_db_txn_count", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_db_txn_count:total",
- ),
-)
+block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"])
+
+block_ru_utime = Counter(
+ "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"])
+
+block_ru_stime = Counter(
+ "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"])
+
+block_db_txn_count = Counter(
+ "synapse_util_metrics_block_db_txn_count", "", ["block_name"])
# seconds spent waiting for db txns, excluding scheduling time, in this block
-block_db_txn_duration = metrics.register_counter(
- "block_db_txn_duration_seconds", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_db_txn_duration:total",
- ),
-)
+block_db_txn_duration = Counter(
+ "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"])
# seconds spent waiting for a db connection, in this block
-block_db_sched_duration = metrics.register_counter(
- "block_db_sched_duration_seconds", labels=["block_name"],
-)
+block_db_sched_duration = Counter(
+ "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
def measure_func(name):
@@ -102,7 +62,7 @@ class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
"ru_stime",
- "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
"created_context",
]
@@ -114,7 +74,7 @@ class Measure(object):
self.created_context = False
def __enter__(self):
- self.start = self.clock.time_msec()
+ self.start = self.clock.time()
self.start_context = LoggingContext.current_context()
if not self.start_context:
self.start_context = LoggingContext("Measure")
@@ -123,17 +83,17 @@ class Measure(object):
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
- self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
+ self.db_txn_duration_sec = self.start_context.db_txn_duration_sec
+ self.db_sched_duration_sec = self.start_context.db_sched_duration_sec
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
- duration = self.clock.time_msec() - self.start
+ duration = self.clock.time() - self.start
- block_counter.inc(self.name)
- block_timer.inc_by(duration, self.name)
+ block_counter.labels(self.name).inc()
+ block_timer.labels(self.name).inc(duration)
context = LoggingContext.current_context()
@@ -150,19 +110,13 @@ class Measure(object):
ru_utime, ru_stime = context.get_resource_usage()
- block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
- block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
- block_db_txn_count.inc_by(
- context.db_txn_count - self.db_txn_count, self.name
- )
- block_db_txn_duration.inc_by(
- (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
- self.name
- )
- block_db_sched_duration.inc_by(
- (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
- self.name
- )
+ block_ru_utime.labels(self.name).inc(ru_utime - self.ru_utime)
+ block_ru_stime.labels(self.name).inc(ru_stime - self.ru_stime)
+ block_db_txn_count.labels(self.name).inc(context.db_txn_count - self.db_txn_count)
+ block_db_txn_duration.labels(self.name).inc(
+ context.db_txn_duration_sec - self.db_txn_duration_sec)
+ block_db_sched_duration.labels(self.name).inc(
+ context.db_sched_duration_sec - self.db_sched_duration_sec)
if self.created_context:
self.start_context.__exit__(exc_type, exc_val, exc_tb)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 1101881a2d..0ab63c3d7d 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,7 +18,10 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import (
+ run_in_background, make_deferred_yieldable,
+ PreserveLoggingContext,
+)
import collections
import contextlib
@@ -150,7 +153,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
- ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
+ ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)
@@ -176,6 +179,9 @@ class _PerHostRatelimiter(object):
return r
def on_err(r):
+ # XXX: why is this necessary? this is called before we start
+ # processing the request so why would the request be in
+ # current_processing?
self.current_processing.discard(request_id)
return r
@@ -187,7 +193,7 @@ class _PerHostRatelimiter(object):
ret_defer.addCallbacks(on_start, on_err)
ret_defer.addBoth(on_both)
- return ret_defer
+ return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
logger.debug(
@@ -197,7 +203,12 @@ class _PerHostRatelimiter(object):
self.current_processing.discard(request_id)
try:
request_id, deferred = self.ready_request_queue.popitem()
+
+ # XXX: why do we do the following? the on_start callback above will
+ # do it for us.
self.current_processing.add(request_id)
- deferred.callback(None)
+
+ with PreserveLoggingContext():
+ deferred.callback(None)
except KeyError:
pass
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 47b0bb5eb3..4e93f69d3a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -203,8 +203,8 @@ class RetryDestinationLimiter(object):
)
except Exception:
logger.exception(
- "Failed to store set_destination_retry_timings",
+ "Failed to store destination_retry_timings",
)
# we deliberately do this in the background.
- synapse.util.logcontext.preserve_fn(store_retry_timings)()
+ synapse.util.logcontext.run_in_background(store_retry_timings)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 95a6168e16..b98b9dc6e4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -15,6 +15,7 @@
import random
import string
+from six.moves import range
_string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -22,12 +23,12 @@ _string_with_symbols = (
def random_string(length):
- return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+ return ''.join(random.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
return ''.join(
- random.choice(_string_with_symbols) for _ in xrange(length)
+ random.choice(_string_with_symbols) for _ in range(length)
)
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index b70f9a6b0a..7a9e45aca9 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from six.moves import range
+
class _Entry(object):
__slots__ = ["end_key", "queue"]
@@ -68,7 +70,7 @@ class WheelTimer(object):
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
self.entries.extend(
- _Entry(key) for key in xrange(last_key, then_key + 1)
+ _Entry(key) for key in range(last_key, then_key + 1)
)
self.entries[-1].queue.append(obj)
|