author     Matthew Hodgson <matthew@matrix.org>    2016-03-27 22:54:42 +0100
committer  Matthew Hodgson <matthew@matrix.org>    2016-03-27 22:54:42 +0100
commit     d9d48aad2d58deb5db422a5373a4dac9334a0618
tree       63e51372ca9ace4971403928bd46440ff9e455e2 /synapse/util
parent     initial WIP of a tentative preview_url endpoint - incomplete, untested, exper...
parent     typo
Merge branch 'develop' into matthew/preview_urls
Diffstat (limited to 'synapse/util')

-rw-r--r--  synapse/util/__init__.py                    |  10
-rw-r--r--  synapse/util/async.py                       |  11
-rw-r--r--  synapse/util/caches/descriptors.py          |  26
-rw-r--r--  synapse/util/caches/expiringcache.py        |  20
-rw-r--r--  synapse/util/caches/lrucache.py             |  10
-rw-r--r--  synapse/util/caches/snapshot_cache.py       |   3
-rw-r--r--  synapse/util/caches/stream_change_cache.py  | 123
-rw-r--r--  synapse/util/caches/treecache.py            |  38
-rw-r--r--  synapse/util/distributor.py                 |  15
-rw-r--r--  synapse/util/logcontext.py                  | 130
-rw-r--r--  synapse/util/logutils.py                    |  37
-rw-r--r--  synapse/util/metrics.py                     |  97
-rw-r--r--  synapse/util/ratelimitutils.py              |   3
-rw-r--r--  synapse/util/wheel_timer.py                 |  97

14 files changed, 563 insertions(+), 57 deletions(-)
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index f1fe963adf..3b9da5b34a 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
 
@@ -42,11 +42,11 @@ class Clock(object):
 
     def time_msec(self):
         """Returns the current system time in miliseconds since epoch."""
-        return self.time() * 1000
+        return int(self.time() * 1000)
 
     def looping_call(self, f, msec):
         l = task.LoopingCall(f)
-        l.start(msec/1000.0, now=False)
+        l.start(msec / 1000.0, now=False)
         return l
 
     def stop_looping_call(self, loop):
@@ -61,10 +61,8 @@ class Clock(object):
             *args: Postional arguments to pass to function.
             **kwargs: Key arguments to pass to function.
         """
-        current_context = LoggingContext.current_context()
-
         def wrapped_callback(*args, **kwargs):
-            with PreserveLoggingContext(current_context):
+            with PreserveLoggingContext():
                 callback(*args, **kwargs)
 
         with PreserveLoggingContext():
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 200edd404c..640fae3890 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,13 +16,16 @@
 
 from twisted.internet import defer, reactor
 
-from .logcontext import preserve_context_over_deferred
+from .logcontext import PreserveLoggingContext
 
 
+@defer.inlineCallbacks
 def sleep(seconds):
     d = defer.Deferred()
-    reactor.callLater(seconds, d.callback, seconds)
-    return preserve_context_over_deferred(d)
+    with PreserveLoggingContext():
+        reactor.callLater(seconds, d.callback, seconds)
+        res = yield d
+    defer.returnValue(res)
 
 
 def run_on_reactor():
@@ -54,6 +57,7 @@ class ObservableDeferred(object):
             object.__setattr__(self, "_result", (True, r))
             while self._observers:
                 try:
+                    # TODO: Handle errors here.
                     self._observers.pop().callback(r)
                 except:
                     pass
@@ -63,6 +67,7 @@ class ObservableDeferred(object):
             object.__setattr__(self, "_result", (False, f))
             while self._observers:
                 try:
+                    # TODO: Handle errors here.
                     self._observers.pop().errback(f)
                 except:
                     pass
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 88e56e3302..35544b19fd 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
 
 from . import caches_by_name, DEBUG_CACHES, cache_counter
 
@@ -25,6 +28,7 @@ from twisted.internet import defer
 
 from collections import OrderedDict
 
+import os
 import functools
 import inspect
 import threading
@@ -35,6 +39,9 @@ logger = logging.getLogger(__name__)
 _CacheSentinel = object()
 
 
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
 class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
@@ -137,6 +144,8 @@ class CacheDescriptor(object):
     """
     def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
                  inlineCallbacks=False):
+        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
+
         self.orig = orig
 
         if inlineCallbacks:
@@ -149,7 +158,7 @@ class CacheDescriptor(object):
         self.lru = lru
         self.tree = tree
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
 
         if len(self.arg_names) < self.num_args:
             raise Exception(
@@ -190,7 +199,7 @@ class CacheDescriptor(object):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return observer
+                return preserve_context_over_deferred(observer)
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
@@ -198,6 +207,7 @@ class CacheDescriptor(object):
                 sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     obj, *args, **kwargs
                 )
@@ -211,7 +221,7 @@ class CacheDescriptor(object):
                 ret = ObservableDeferred(ret, consumeErrors=True)
                 self.cache.update(sequence, cache_key, ret)
 
-                return ret.observe()
+                return preserve_context_over_deferred(ret.observe())
 
         wrapped.invalidate = self.cache.invalidate
         wrapped.invalidate_all = self.cache.invalidate_all
@@ -250,7 +260,7 @@ class CacheListDescriptor(object):
         self.num_args = num_args
         self.list_name = list_name
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
 
         self.cache = cache
@@ -299,6 +309,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     **args_to_call
                 )
@@ -308,7 +319,8 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    observer = ret_d.observe()
+                    with PreserveLoggingContext():
+                        observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -327,10 +339,10 @@ class CacheListDescriptor(object):
 
                     cached[arg] = res
 
-            return defer.gatherResults(
+            return preserve_context_over_deferred(defer.gatherResults(
                 cached.values(),
                 consumeErrors=True,
-            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
 
         obj.__dict__[self.orig.__name__] = wrapped
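Note on the new `CACHE_SIZE_FACTOR` knob above: every `@cached` descriptor now scales its `max_entries` by the `SYNAPSE_CACHE_FACTOR` environment variable before building its cache. A minimal sketch of the arithmetic (the `effective_max_entries` helper is invented for illustration; the descriptor inlines this calculation):

```python
import os

# Mirrors the constant added in descriptors.py: defaults to 0.1
# when the environment variable is unset.
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))


def effective_max_entries(max_entries=1000):
    # Hypothetical helper showing what CacheDescriptor.__init__ now does.
    return int(max_entries * CACHE_SIZE_FACTOR)


if "SYNAPSE_CACHE_FACTOR" not in os.environ:
    # With the 0.1 default, a nominal 1000-entry cache holds 100 entries.
    assert effective_max_entries(1000) == 100
```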
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 494226f5ea..2b68c1ac93 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.util.caches import cache_counter, caches_by_name
+
 import logging
 
 
@@ -47,6 +49,8 @@ class ExpiringCache(object):
 
         self._cache = {}
 
+        caches_by_name[cache_name] = self._cache
+
     def start(self):
         if not self._expiry_ms:
             # Don't bother starting the loop if things never expire
@@ -55,7 +59,7 @@ class ExpiringCache(object):
         def f():
             self._prune_cache()
 
-        self._clock.looping_call(f, self._expiry_ms/2)
+        self._clock.looping_call(f, self._expiry_ms / 2)
 
     def __setitem__(self, key, value):
         now = self._clock.time_msec()
@@ -65,14 +69,19 @@ class ExpiringCache(object):
         if self._max_len and len(self._cache.keys()) > self._max_len:
             sorted_entries = sorted(
                 self._cache.items(),
-                key=lambda (k, v): v.time,
+                key=lambda item: item[1].time,
             )
 
             for k, _ in sorted_entries[self._max_len:]:
                 self._cache.pop(k)
 
     def __getitem__(self, key):
-        entry = self._cache[key]
+        try:
+            entry = self._cache[key]
+            cache_counter.inc_hits(self._cache_name)
+        except KeyError:
+            cache_counter.inc_misses(self._cache_name)
+            raise
 
         if self._reset_expiry_on_get:
             entry.time = self._clock.time_msec()
@@ -105,9 +114,12 @@ class ExpiringCache(object):
 
         logger.debug(
             "[%s] _prune_cache before: %d, after len: %d",
-            self._cache_name, begin_length, len(self._cache.keys())
+            self._cache_name, begin_length, len(self._cache)
        )
 
+    def __len__(self):
+        return len(self._cache)
+
 
 class _CacheEntry(object):
     def __init__(self, time, value):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e6a66dc041..f7423f2fab 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -37,7 +37,7 @@ class LruCache(object):
     """
     def __init__(self, max_size, keylen=1, cache_type=dict):
         cache = cache_type()
-        self.size = 0
+        self.cache = cache  # Used for introspection.
         list_root = []
         list_root[:] = [list_root, list_root, None, None]
 
@@ -60,7 +60,6 @@ class LruCache(object):
             prev_node[NEXT] = node
             next_node[PREV] = node
             cache[key] = node
-            self.size += 1
 
         def move_node_to_front(node):
             prev_node = node[PREV]
@@ -79,7 +78,6 @@ class LruCache(object):
             next_node = node[NEXT]
             prev_node[NEXT] = next_node
             next_node[PREV] = prev_node
-            self.size -= 1
 
         @synchronized
         def cache_get(key, default=None):
@@ -98,7 +96,7 @@ class LruCache(object):
                 node[VALUE] = value
             else:
                 add_node(key, value)
-                if self.size > max_size:
+                if len(cache) > max_size:
                     todelete = list_root[PREV]
                     delete_node(todelete)
                     cache.pop(todelete[KEY], None)
@@ -110,7 +108,7 @@ class LruCache(object):
                 return node[VALUE]
             else:
                 add_node(key, value)
-                if self.size > max_size:
+                if len(cache) > max_size:
                     todelete = list_root[PREV]
                     delete_node(todelete)
                     cache.pop(todelete[KEY], None)
@@ -145,7 +143,7 @@ class LruCache(object):
 
         @synchronized
         def cache_len():
-            return self.size
+            return len(cache)
 
         @synchronized
         def cache_contains(key):
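The `LruCache` hunks above drop the hand-maintained `self.size` counter in favour of `len(cache)`, making the backing dict the single source of truth for size. An illustrative stand-in (not Synapse's implementation, which keeps the intrusive linked list for O(1) eviction) showing the same invariant with an `OrderedDict`:

```python
from collections import OrderedDict


class TinyLru(object):
    """Illustrative stand-in for the 'len(cache) is the size' invariant;
    most-recently used entries live at the end of the OrderedDict."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.cache = OrderedDict()

    def set(self, key, value):
        if key in self.cache:
            self.cache.pop(key)
        self.cache[key] = value
        if len(self.cache) > self.max_size:   # mirrors `len(cache) > max_size`
            self.cache.popitem(last=False)    # evict the least-recently used

    def get(self, key, default=None):
        if key in self.cache:
            value = self.cache.pop(key)
            self.cache[key] = value           # refresh recency
            return value
        return default
```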
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index b1e40417fd..d03678b8c8 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -87,7 +87,8 @@ class SnapshotCache(object):
             # expire from the rotation of that cache.
             self.next_result_cache[key] = result
             self.pending_result_cache.pop(key, None)
+            return r
 
-        result.observe().addBoth(shuffle_along)
+        result.addBoth(shuffle_along)
 
         return result.observe()
+ """ + assert type(stream_pos) is int + + if stream_pos >= self._earliest_known_stream_pos: + keys = self._cache.keys() + i = keys.bisect_right(stream_pos) + + return [self._cache[k] for k in keys[i:]] + else: + return None + + def entity_has_changed(self, entity, stream_pos): + """Informs the cache that the entity has been changed at the given + position. + """ + assert type(stream_pos) is int + + if stream_pos > self._earliest_known_stream_pos: + old_pos = self._entity_to_key.get(entity, None) + if old_pos is not None: + stream_pos = max(stream_pos, old_pos) + self._cache.pop(old_pos, None) + self._cache[stream_pos] = entity + self._entity_to_key[entity] = stream_pos + + while len(self._cache) > self._max_size: + k, r = self._cache.popitem() + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + self._entity_to_key.pop(r, None) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 3b58860910..03bc1401b7 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -8,6 +8,7 @@ class TreeCache(object): Keys must be tuples. """ def __init__(self): + self.size = 0 self.root = {} def __setitem__(self, key, value): @@ -20,7 +21,8 @@ class TreeCache(object): node = self.root for k in key[:-1]: node = node.setdefault(k, {}) - node[key[-1]] = value + node[key[-1]] = _Entry(value) + self.size += 1 def get(self, key, default=None): node = self.root @@ -28,9 +30,10 @@ class TreeCache(object): node = node.get(k, None) if node is None: return default - return node.get(key[-1], default) + return node.get(key[-1], _Entry(default)).value def clear(self): + self.size = 0 self.root = {} def pop(self, key, default=None): @@ -55,6 +58,35 @@ class TreeCache(object): if n: break - node_and_keys[i+1][0].pop(k) + node_and_keys[i + 1][0].pop(k) + popped, cnt = _strip_and_count_entires(popped) + self.size -= cnt return popped + + def __len__(self): + return self.size + + +class _Entry(object): + __slots__ = ["value"] + + def __init__(self, value): + self.value = value + + +def _strip_and_count_entires(d): + """Takes an _Entry or dict with leaves of _Entry's, and either returns the + value or a dictionary with _Entry's replaced by their values. + + Also returns the count of _Entry's + """ + if isinstance(d, dict): + cnt = 0 + for key, value in d.items(): + v, n = _strip_and_count_entires(value) + d[key] = v + cnt += n + return d, cnt + else: + return d.value, 1 diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 4ebfebf701..8875813de4 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -15,9 +15,7 @@ from twisted.internet import defer -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_deferred, -) +from synapse.util.logcontext import PreserveLoggingContext from synapse.util import unwrapFirstError @@ -97,6 +95,7 @@ class Signal(object): Each observer callable may return a Deferred.""" self.observers.append(observer) + @defer.inlineCallbacks def fire(self, *args, **kwargs): """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. 
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 4ebfebf701..8875813de4 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -15,9 +15,7 @@
 
 from twisted.internet import defer
 
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import PreserveLoggingContext
 
 from synapse.util import unwrapFirstError
 
@@ -97,6 +95,7 @@ class Signal(object):
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
+    @defer.inlineCallbacks
     def fire(self, *args, **kwargs):
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -116,6 +115,7 @@ class Signal(object):
                         failure.getTracebackObject()))
                 if not self.suppress_failures:
                     return failure
+
             return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
 
         with PreserveLoggingContext():
@@ -124,8 +124,11 @@ class Signal(object):
                 for observer in self.observers
             ]
 
-            d = defer.gatherResults(deferreds, consumeErrors=True)
+            res = yield defer.gatherResults(
+                deferreds, consumeErrors=True
+            ).addErrback(unwrapFirstError)
 
-        d.addErrback(unwrapFirstError)
+        defer.returnValue(res)
 
-        return preserve_context_over_deferred(d)
+    def __repr__(self):
+        return "<Signal name=%r>" % (self.name,)
""" - current = self.set_current_context(self.parent_context) + current = self.set_current_context(self.previous_context) if current is not self: if current is self.sentinel: logger.debug("Expected logging context %s has been lost", self) @@ -130,16 +141,11 @@ class LoggingContext(object): current, self ) - self.parent_context = None - - def __getattr__(self, name): - """Delegate member lookup to parent context""" - return getattr(self.parent_context, name) + self.previous_context = None + self.alive = False def copy_to(self, record): - """Copy fields from this context and its parents to the record""" - if self.parent_context is not None: - self.parent_context.copy_to(record) + """Copy fields from this context to the record""" for key, value in self.__dict__.items(): setattr(record, key, value) @@ -208,7 +214,7 @@ class PreserveLoggingContext(object): exited. Used to restore the context after a function using @defer.inlineCallbacks is resumed by a callback from the reactor.""" - __slots__ = ["current_context", "new_context"] + __slots__ = ["current_context", "new_context", "has_parent"] def __init__(self, new_context=LoggingContext.sentinel): self.new_context = new_context @@ -219,12 +225,27 @@ class PreserveLoggingContext(object): self.new_context ) + if self.current_context: + self.has_parent = self.current_context.previous_context is not None + if not self.current_context.alive: + logger.debug( + "Entering dead context: %s", + self.current_context, + ) + def __exit__(self, type, value, traceback): """Restores the current logging context""" - LoggingContext.set_current_context(self.current_context) + context = LoggingContext.set_current_context(self.current_context) + + if context != self.new_context: + logger.debug( + "Unexpected logging context: %s is not %s", + context, self.new_context, + ) + if self.current_context is not LoggingContext.sentinel: - if self.current_context.parent_context is None: - logger.warn( + if not self.current_context.alive: + logger.debug( "Restoring dead context: %s", self.current_context, ) @@ -284,3 +305,74 @@ def preserve_context_over_deferred(deferred): d = _PreservingContextDeferred(current_context) deferred.chainDeferred(d) return d + + +def preserve_fn(f): + """Ensures that function is called with correct context and that context is + restored after return. Useful for wrapping functions that return a deferred + which you don't yield on. + """ + current = LoggingContext.current_context() + + def g(*args, **kwargs): + with PreserveLoggingContext(current): + return f(*args, **kwargs) + + return g + + +# modules to ignore in `logcontext_tracer` +_to_ignore = [ + "synapse.util.logcontext", + "synapse.http.server", + "synapse.storage._base", + "synapse.util.async", +] + + +def logcontext_tracer(frame, event, arg): + """A tracer that logs whenever a logcontext "unexpectedly" changes within + a function. Probably inaccurate. + + Use by calling `sys.settrace(logcontext_tracer)` in the main thread. 
+ """ + if event == 'call': + name = frame.f_globals["__name__"] + if name.startswith("synapse"): + if name == "synapse.util.logcontext": + if frame.f_code.co_name in ["__enter__", "__exit__"]: + tracer = frame.f_back.f_trace + if tracer: + tracer.just_changed = True + + tracer = frame.f_trace + if tracer: + return tracer + + if not any(name.startswith(ig) for ig in _to_ignore): + return LineTracer() + + +class LineTracer(object): + __slots__ = ["context", "just_changed"] + + def __init__(self): + self.context = LoggingContext.current_context() + self.just_changed = False + + def __call__(self, frame, event, arg): + if event in 'line': + if self.just_changed: + self.context = LoggingContext.current_context() + self.just_changed = False + else: + c = LoggingContext.current_context() + if c != self.context: + logger.info( + "Context changed! %s -> %s, %s, %s", + self.context, c, + frame.f_code.co_filename, frame.f_lineno + ) + self.context = c + + return self diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index d5b1a37eff..3a83828d25 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -111,7 +111,7 @@ def time_function(f): _log_debug_as_f( f, "[FUNC END] {%s-%d} %f", - (func_name, id, end-start,), + (func_name, id, end - start,), ) return r @@ -168,3 +168,38 @@ def trace_function(f): wrapped.__name__ = func_name return wrapped + + +def get_previous_frames(): + s = inspect.currentframe().f_back.f_back + to_return = [] + while s: + if s.f_globals["__name__"].startswith("synapse"): + filename, lineno, function, _, _ = inspect.getframeinfo(s) + args_string = inspect.formatargvalues(*inspect.getargvalues(s)) + + to_return.append("{{ %s:%d %s - Args: %s }}" % ( + filename, lineno, function, args_string + )) + + s = s.f_back + + return ", ". join(to_return) + + +def get_previous_frame(ignore=[]): + s = inspect.currentframe().f_back.f_back + + while s: + if s.f_globals["__name__"].startswith("synapse"): + if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore): + filename, lineno, function, _, _ = inspect.getframeinfo(s) + args_string = inspect.formatargvalues(*inspect.getargvalues(s)) + + return "{{ %s:%d %s - Args: %s }}" % ( + filename, lineno, function, args_string + ) + + s = s.f_back + + return None diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py new file mode 100644 index 0000000000..c51b641125 --- /dev/null +++ b/synapse/util/metrics.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
new file mode 100644
index 0000000000..c51b641125
--- /dev/null
+++ b/synapse/util/metrics.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.util.logcontext import LoggingContext
+import synapse.metrics
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+block_timer = metrics.register_distribution(
+    "block_timer",
+    labels=["block_name"]
+)
+
+block_ru_utime = metrics.register_distribution(
+    "block_ru_utime", labels=["block_name"]
+)
+
+block_ru_stime = metrics.register_distribution(
+    "block_ru_stime", labels=["block_name"]
+)
+
+block_db_txn_count = metrics.register_distribution(
+    "block_db_txn_count", labels=["block_name"]
+)
+
+block_db_txn_duration = metrics.register_distribution(
+    "block_db_txn_duration", labels=["block_name"]
+)
+
+
+class Measure(object):
+    __slots__ = [
+        "clock", "name", "start_context", "start", "new_context", "ru_utime",
+        "ru_stime", "db_txn_count", "db_txn_duration"
+    ]
+
+    def __init__(self, clock, name):
+        self.clock = clock
+        self.name = name
+        self.start_context = None
+        self.start = None
+
+    def __enter__(self):
+        self.start = self.clock.time_msec()
+        self.start_context = LoggingContext.current_context()
+        if self.start_context:
+            self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
+            self.db_txn_count = self.start_context.db_txn_count
+            self.db_txn_duration = self.start_context.db_txn_duration
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None or not self.start_context:
+            return
+
+        duration = self.clock.time_msec() - self.start
+        block_timer.inc_by(duration, self.name)
+
+        context = LoggingContext.current_context()
+
+        if context != self.start_context:
+            logger.warn(
+                "Context have unexpectedly changed from '%s' to '%s'. (%r)",
+                context, self.start_context, self.name
+            )
+            return
+
+        if not context:
+            logger.warn("Expected context. (%r)", self.name)
+            return
+
+        ru_utime, ru_stime = context.get_resource_usage()
+
+        block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
+        block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
+        block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
+        block_db_txn_duration.inc_by(
+            context.db_txn_duration - self.db_txn_duration, self.name
+        )
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index c37d6f12e3..4076eed269 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import LimitExceededError
 
 from synapse.util.async import sleep
+from synapse.util.logcontext import preserve_fn
 
 import collections
 import contextlib
@@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
                 "Ratelimit [%s]: sleeping req",
                 id(request_id),
             )
-            ret_defer = sleep(self.sleep_msec/1000.0)
+            ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
 
             self.sleeping_requests.add(request_id)
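The ratelimiter hunk above is the first caller of the new `preserve_fn`. For the new `metrics.py`, a hypothetical `Measure` usage sketch (the block name "example_block" is invented; this assumes a Synapse environment where `synapse.metrics` is importable and a `Clock` is available):

```python
from synapse.util import Clock
from synapse.util.metrics import Measure

clock = Clock()


def process_request():
    # Times the block and attributes any CPU / DB usage recorded on the
    # current LoggingContext to the "example_block" label.
    with Measure(clock, "example_block"):
        pass  # the real work being measured would go here
```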
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
new file mode 100644
index 0000000000..7412fc57a4
--- /dev/null
+++ b/synapse/util/wheel_timer.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class _Entry(object):
+    __slots__ = ["end_key", "queue"]
+
+    def __init__(self, end_key):
+        self.end_key = end_key
+        self.queue = []
+
+
+class WheelTimer(object):
+    """Stores arbitrary objects that will be returned after their timers have
+    expired.
+    """
+
+    def __init__(self, bucket_size=5000):
+        """
+        Args:
+            bucket_size (int): Size of buckets in ms. Corresponds roughly to the
+                accuracy of the timer.
+        """
+        self.bucket_size = bucket_size
+        self.entries = []
+        self.current_tick = 0
+
+    def insert(self, now, obj, then):
+        """Inserts object into timer.
+
+        Args:
+            now (int): Current time in msec
+            obj (object): Object to be inserted
+            then (int): When to return the object strictly after.
+        """
+        then_key = int(then / self.bucket_size) + 1
+
+        if self.entries:
+            min_key = self.entries[0].end_key
+            max_key = self.entries[-1].end_key
+
+            if then_key <= max_key:
+                # The max here is to protect against inserts for times in the past
+                self.entries[max(min_key, then_key) - min_key].queue.append(obj)
+                return
+
+        next_key = int(now / self.bucket_size) + 1
+        if self.entries:
+            last_key = self.entries[-1].end_key
+        else:
+            last_key = next_key
+
+        # Handle the case when `then` is in the past and `entries` is empty.
+        then_key = max(last_key, then_key)
+
+        # Add empty entries between the end of the current list and when we want
+        # to insert. This ensures there are no gaps.
+        self.entries.extend(
+            _Entry(key) for key in xrange(last_key, then_key + 1)
+        )
+
+        self.entries[-1].queue.append(obj)
+
+    def fetch(self, now):
+        """Fetch any objects that have timed out
+
+        Args:
+            now (ms): Current time in msec
+
+        Returns:
+            list: List of objects that have timed out
+        """
+        now_key = int(now / self.bucket_size)
+
+        ret = []
+        while self.entries and self.entries[0].end_key <= now_key:
+            ret.extend(self.entries.pop(0).queue)
+
+        return ret
+
+    def __len__(self):
+        l = 0
+        for entry in self.entries:
+            l += len(entry.queue)
+        return l
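A usage sketch for the new `WheelTimer`, with invented timestamps: objects inserted with a deadline come back from the first `fetch()` whose `now` has passed the deadline's bucket, so accuracy is bounded by `bucket_size` (note the module uses `xrange`, so this runs under the Python 2 of this era):

```python
from synapse.util.wheel_timer import WheelTimer

timer = WheelTimer(bucket_size=5000)  # 5-second buckets

# At t=0ms, schedule an object to be returned strictly after t=12000ms.
timer.insert(now=0, obj="ping-alice", then=12000)

assert timer.fetch(5000) == []                # too early: nothing due yet
assert timer.fetch(20000) == ["ping-alice"]   # deadline bucket has passed
assert len(timer) == 0                        # fetched objects are removed
```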