diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index f1fe963adf..3b9da5b34a 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -42,11 +42,11 @@ class Clock(object):
def time_msec(self):
"""Returns the current system time in miliseconds since epoch."""
- return self.time() * 1000
+ return int(self.time() * 1000)
def looping_call(self, f, msec):
l = task.LoopingCall(f)
- l.start(msec/1000.0, now=False)
+ l.start(msec / 1000.0, now=False)
return l
def stop_looping_call(self, loop):
@@ -61,10 +61,8 @@ class Clock(object):
*args: Positional arguments to pass to function.
**kwargs: Keyword arguments to pass to function.
"""
- current_context = LoggingContext.current_context()
-
def wrapped_callback(*args, **kwargs):
- with PreserveLoggingContext(current_context):
+ with PreserveLoggingContext():
callback(*args, **kwargs)
with PreserveLoggingContext():
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 200edd404c..640fae3890 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,13 +16,16 @@
from twisted.internet import defer, reactor
-from .logcontext import preserve_context_over_deferred
+from .logcontext import PreserveLoggingContext
+@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
- reactor.callLater(seconds, d.callback, seconds)
- return preserve_context_over_deferred(d)
+ with PreserveLoggingContext():
+ reactor.callLater(seconds, d.callback, seconds)
+ res = yield d
+ defer.returnValue(res)
def run_on_reactor():
@@ -54,6 +57,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
+ # TODO: Handle errors here.
self._observers.pop().callback(r)
except:
pass
@@ -63,6 +67,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
+ # TODO: Handle errors here.
self._observers.pop().errback(f)
except:
pass
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 88e56e3302..35544b19fd 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
from . import caches_by_name, DEBUG_CACHES, cache_counter
@@ -25,6 +28,7 @@ from twisted.internet import defer
from collections import OrderedDict
+import os
import functools
import inspect
import threading
@@ -35,6 +39,9 @@ logger = logging.getLogger(__name__)
_CacheSentinel = object()
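+# Scale factor applied to the configured maximum size of each cache; it can
+# be tuned with the SYNAPSE_CACHE_FACTOR environment variable (default 0.1).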
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
@@ -137,6 +144,8 @@ class CacheDescriptor(object):
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
inlineCallbacks=False):
+ max_entries = int(max_entries * CACHE_SIZE_FACTOR)
+
self.orig = orig
if inlineCallbacks:
@@ -149,7 +158,7 @@ class CacheDescriptor(object):
self.lru = lru
self.tree = tree
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
if len(self.arg_names) < self.num_args:
raise Exception(
@@ -190,7 +199,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return observer
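+ # Wrap the observer so the caller's logcontext is restored when the
+ # cached result fires.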
+ return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
@@ -198,6 +207,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence
ret = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
obj, *args, **kwargs
)
@@ -211,7 +221,7 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
- return ret.observe()
+ return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
@@ -250,7 +260,7 @@ class CacheListDescriptor(object):
self.num_args = num_args
self.list_name = list_name
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
@@ -299,6 +309,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
**args_to_call
)
@@ -308,7 +319,8 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- observer = ret_d.observe()
+ with PreserveLoggingContext():
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -327,10 +339,10 @@ class CacheListDescriptor(object):
cached[arg] = res
- return defer.gatherResults(
+ return preserve_context_over_deferred(defer.gatherResults(
cached.values(),
consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+ ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 494226f5ea..2b68c1ac93 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.util.caches import cache_counter, caches_by_name
+
import logging
@@ -47,6 +49,8 @@ class ExpiringCache(object):
self._cache = {}
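+ # Expose the underlying dict so that the cache metrics can report on it.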
+ caches_by_name[cache_name] = self._cache
+
def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
@@ -55,7 +59,7 @@ class ExpiringCache(object):
def f():
self._prune_cache()
- self._clock.looping_call(f, self._expiry_ms/2)
+ self._clock.looping_call(f, self._expiry_ms / 2)
def __setitem__(self, key, value):
now = self._clock.time_msec()
@@ -65,14 +69,19 @@ class ExpiringCache(object):
if self._max_len and len(self._cache.keys()) > self._max_len:
sorted_entries = sorted(
self._cache.items(),
- key=lambda (k, v): v.time,
+ key=lambda item: item[1].time,
)
for k, _ in sorted_entries[self._max_len:]:
self._cache.pop(k)
def __getitem__(self, key):
- entry = self._cache[key]
+ try:
+ entry = self._cache[key]
+ cache_counter.inc_hits(self._cache_name)
+ except KeyError:
+ cache_counter.inc_misses(self._cache_name)
+ raise
if self._reset_expiry_on_get:
entry.time = self._clock.time_msec()
@@ -105,9 +114,12 @@ class ExpiringCache(object):
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
- self._cache_name, begin_length, len(self._cache.keys())
+ self._cache_name, begin_length, len(self._cache)
)
+ def __len__(self):
+ return len(self._cache)
+
class _CacheEntry(object):
def __init__(self, time, value):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e6a66dc041..f7423f2fab 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -37,7 +37,7 @@ class LruCache(object):
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
- self.size = 0
+ self.cache = cache # Used for introspection.
list_root = []
list_root[:] = [list_root, list_root, None, None]
@@ -60,7 +60,6 @@ class LruCache(object):
prev_node[NEXT] = node
next_node[PREV] = node
cache[key] = node
- self.size += 1
def move_node_to_front(node):
prev_node = node[PREV]
@@ -79,7 +78,6 @@ class LruCache(object):
next_node = node[NEXT]
prev_node[NEXT] = next_node
next_node[PREV] = prev_node
- self.size -= 1
@synchronized
def cache_get(key, default=None):
@@ -98,7 +96,7 @@ class LruCache(object):
node[VALUE] = value
else:
add_node(key, value)
- if self.size > max_size:
+ if len(cache) > max_size:
todelete = list_root[PREV]
delete_node(todelete)
cache.pop(todelete[KEY], None)
@@ -110,7 +108,7 @@ class LruCache(object):
return node[VALUE]
else:
add_node(key, value)
- if self.size > max_size:
+ if len(cache) > max_size:
todelete = list_root[PREV]
delete_node(todelete)
cache.pop(todelete[KEY], None)
@@ -145,7 +143,7 @@ class LruCache(object):
@synchronized
def cache_len():
- return self.size
+ return len(cache)
@synchronized
def cache_contains(key):
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index b1e40417fd..d03678b8c8 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache.
self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None)
+ return r
- result.observe().addBoth(shuffle_along)
+ result.addBoth(shuffle_along)
return result.observe()
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
new file mode 100644
index 0000000000..ea8a74ca69
--- /dev/null
+++ b/synapse/util/caches/stream_change_cache.py
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches import cache_counter, caches_by_name
+
+
+from blist import sorteddict
+import logging
+import os
+
+
+logger = logging.getLogger(__name__)
+
+
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
+class StreamChangeCache(object):
+ """Keeps track of the stream positions of the latest change in a set of entities.
+
+ Typically the entity will be a room or user id.
+
+ Given a list of entities and a stream position, it will give a subset of
+ entities that may have changed since that position. If the position is too
+ old then the cache will simply return all of the given entities.
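+
+ Illustrative usage (the entity id here is made up)::
+
+ cache = StreamChangeCache("ExampleCache", current_stream_pos=10)
+ cache.entity_has_changed("@user:example.com", 12)
+ cache.has_entity_changed("@user:example.com", 11) # True
+ cache.has_entity_changed("@user:example.com", 12) # False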
+ """
+ def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None):
+ self._max_size = int(max_size * CACHE_SIZE_FACTOR)
+ self._entity_to_key = {}
+ self._cache = sorteddict()
+ self._earliest_known_stream_pos = current_stream_pos
+ self.name = name
+ caches_by_name[self.name] = self._cache
+
+ for entity, stream_pos in (prefilled_cache or {}).items():
+ self.entity_has_changed(entity, stream_pos)
+
+ def has_entity_changed(self, entity, stream_pos):
+ """Returns True if the entity may have been updated since stream_pos
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos < self._earliest_known_stream_pos:
+ cache_counter.inc_misses(self.name)
+ return True
+
+ latest_entity_change_pos = self._entity_to_key.get(entity, None)
+ if latest_entity_change_pos is None:
+ cache_counter.inc_hits(self.name)
+ return False
+
+ if stream_pos < latest_entity_change_pos:
+ cache_counter.inc_misses(self.name)
+ return True
+
+ cache_counter.inc_hits(self.name)
+ return False
+
+ def get_entities_changed(self, entities, stream_pos):
+ """Returns the subset of the given entities that may have changed since
+ the given position. If the position is too old it will just return the
+ given list.
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos >= self._earliest_known_stream_pos:
+ keys = self._cache.keys()
+ i = keys.bisect_right(stream_pos)
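+ # keys[i:] are the stream positions strictly greater than stream_pos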
+
+ result = set(
+ self._cache[k] for k in keys[i:]
+ ).intersection(entities)
+
+ cache_counter.inc_hits(self.name)
+ else:
+ result = entities
+ cache_counter.inc_misses(self.name)
+
+ return result
+
+ def get_all_entities_changed(self, stream_pos):
+ """Returns all entities that have had new things since the given
+ position. If the position is too old it will return None.
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos >= self._earliest_known_stream_pos:
+ keys = self._cache.keys()
+ i = keys.bisect_right(stream_pos)
+
+ return [self._cache[k] for k in keys[i:]]
+ else:
+ return None
+
+ def entity_has_changed(self, entity, stream_pos):
+ """Informs the cache that the entity has been changed at the given
+ position.
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos > self._earliest_known_stream_pos:
+ old_pos = self._entity_to_key.get(entity, None)
+ if old_pos is not None:
+ stream_pos = max(stream_pos, old_pos)
+ self._cache.pop(old_pos, None)
+ self._cache[stream_pos] = entity
+ self._entity_to_key[entity] = stream_pos
+
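+ # Evict the lowest stream positions once we grow too large, remembering
+ # that we no longer have full knowledge before the evicted positions.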
+ while len(self._cache) > self._max_size:
+ k, r = self._cache.popitem()
+ self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+ self._entity_to_key.pop(r, None)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 3b58860910..03bc1401b7 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -8,6 +8,7 @@ class TreeCache(object):
Keys must be tuples.
"""
def __init__(self):
+ self.size = 0
self.root = {}
def __setitem__(self, key, value):
@@ -20,7 +21,8 @@ class TreeCache(object):
node = self.root
for k in key[:-1]:
node = node.setdefault(k, {})
- node[key[-1]] = value
+ node[key[-1]] = _Entry(value)
+ self.size += 1
def get(self, key, default=None):
node = self.root
@@ -28,9 +30,10 @@ class TreeCache(object):
node = node.get(k, None)
if node is None:
return default
- return node.get(key[-1], default)
+ return node.get(key[-1], _Entry(default)).value
def clear(self):
+ self.size = 0
self.root = {}
def pop(self, key, default=None):
@@ -55,6 +58,35 @@ class TreeCache(object):
if n:
break
- node_and_keys[i+1][0].pop(k)
+ node_and_keys[i + 1][0].pop(k)
+ popped, cnt = _strip_and_count_entires(popped)
+ self.size -= cnt
return popped
+
+ def __len__(self):
+ return self.size
+
+
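+# Wrapper for leaf values, so that leaves can be told apart from the
+# intermediate dict nodes when counting and stripping entries.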
+class _Entry(object):
+ __slots__ = ["value"]
+
+ def __init__(self, value):
+ self.value = value
+
+
+def _strip_and_count_entires(d):
+ """Takes an _Entry or a dict whose leaves are _Entry objects, and returns
+ either the value or a dictionary with the _Entry objects replaced by their
+ values.
+
+ Also returns the count of _Entry objects found.
+ """
+ if isinstance(d, dict):
+ cnt = 0
+ for key, value in d.items():
+ v, n = _strip_and_count_entires(value)
+ d[key] = v
+ cnt += n
+ return d, cnt
+ else:
+ return d.value, 1
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 4ebfebf701..8875813de4 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -15,9 +15,7 @@
from twisted.internet import defer
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import PreserveLoggingContext
from synapse.util import unwrapFirstError
@@ -97,6 +95,7 @@ class Signal(object):
Each observer callable may return a Deferred."""
self.observers.append(observer)
+ @defer.inlineCallbacks
def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -116,6 +115,7 @@ class Signal(object):
failure.getTracebackObject()))
if not self.suppress_failures:
return failure
+
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext():
@@ -124,8 +124,11 @@ class Signal(object):
for observer in self.observers
]
- d = defer.gatherResults(deferreds, consumeErrors=True)
+ res = yield defer.gatherResults(
+ deferreds, consumeErrors=True
+ ).addErrback(unwrapFirstError)
- d.addErrback(unwrapFirstError)
+ defer.returnValue(res)
- return preserve_context_over_deferred(d)
+ def __repr__(self):
+ return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 0595c0fa4f..5316259d15 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -41,13 +41,14 @@ except:
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
- "with" block. Contexts inherit the state of their parent contexts.
+ "with" block.
Args:
name (str): Name for the context for debugging.
"""
__slots__ = [
- "parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__"
+ "previous_context", "name", "usage_start", "usage_end", "main_thread",
+ "__dict__", "tag", "alive",
]
thread_local = threading.local()
@@ -72,10 +73,13 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
+ def __nonzero__(self):
+ return False
+
sentinel = Sentinel()
def __init__(self, name=None):
- self.parent_context = None
+ self.previous_context = LoggingContext.current_context()
self.name = name
self.ru_stime = 0.
self.ru_utime = 0.
@@ -83,6 +87,8 @@ class LoggingContext(object):
self.db_txn_duration = 0.
self.usage_start = None
self.main_thread = threading.current_thread()
+ self.tag = ""
+ self.alive = True
def __str__(self):
return "%s@%x" % (self.name, id(self))
@@ -101,6 +107,7 @@ class LoggingContext(object):
The context that was previously active
"""
current = cls.current_context()
+
if current is not context:
current.stop()
cls.thread_local.current_context = context
@@ -109,9 +116,13 @@ class LoggingContext(object):
def __enter__(self):
"""Enters this logging context into thread local storage"""
- if self.parent_context is not None:
- raise Exception("Attempt to enter logging context multiple times")
- self.parent_context = self.set_current_context(self)
+ old_context = self.set_current_context(self)
+ if self.previous_context != old_context:
+ logger.warn(
+ "Expected previous context %r, found %r",
+ self.previous_context, old_context
+ )
+ self.alive = True
return self
def __exit__(self, type, value, traceback):
@@ -120,7 +131,7 @@ class LoggingContext(object):
Returns:
None to avoid suppressing any exceptions that were thrown.
"""
- current = self.set_current_context(self.parent_context)
+ current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
logger.debug("Expected logging context %s has been lost", self)
@@ -130,16 +141,11 @@ class LoggingContext(object):
current,
self
)
- self.parent_context = None
-
- def __getattr__(self, name):
- """Delegate member lookup to parent context"""
- return getattr(self.parent_context, name)
+ self.previous_context = None
+ self.alive = False
def copy_to(self, record):
- """Copy fields from this context and its parents to the record"""
- if self.parent_context is not None:
- self.parent_context.copy_to(record)
+ """Copy fields from this context to the record"""
for key, value in self.__dict__.items():
setattr(record, key, value)
@@ -208,7 +214,7 @@ class PreserveLoggingContext(object):
exited. Used to restore the context after a function using
@defer.inlineCallbacks is resumed by a callback from the reactor."""
- __slots__ = ["current_context", "new_context"]
+ __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=LoggingContext.sentinel):
self.new_context = new_context
@@ -219,12 +225,27 @@ class PreserveLoggingContext(object):
self.new_context
)
+ if self.current_context:
+ self.has_parent = self.current_context.previous_context is not None
+ if not self.current_context.alive:
+ logger.debug(
+ "Entering dead context: %s",
+ self.current_context,
+ )
+
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
- LoggingContext.set_current_context(self.current_context)
+ context = LoggingContext.set_current_context(self.current_context)
+
+ if context != self.new_context:
+ logger.debug(
+ "Unexpected logging context: %s is not %s",
+ context, self.new_context,
+ )
+
if self.current_context is not LoggingContext.sentinel:
- if self.current_context.parent_context is None:
- logger.warn(
+ if not self.current_context.alive:
+ logger.debug(
"Restoring dead context: %s",
self.current_context,
)
@@ -284,3 +305,74 @@ def preserve_context_over_deferred(deferred):
d = _PreservingContextDeferred(current_context)
deferred.chainDeferred(d)
return d
+
+
+def preserve_fn(f):
+ """Ensures that the function is called with the correct logging context and
+ that the context is restored after it returns. Useful for wrapping functions
+ that return a deferred which you don't yield on.
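+
+ For example, to fire off a deferred-returning function without yielding on
+ it, as the federation ratelimiter below does with `sleep`::
+
+ preserve_fn(sleep)(0.5)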
+ """
+ current = LoggingContext.current_context()
+
+ def g(*args, **kwargs):
+ with PreserveLoggingContext(current):
+ return f(*args, **kwargs)
+
+ return g
+
+
+# modules to ignore in `logcontext_tracer`
+_to_ignore = [
+ "synapse.util.logcontext",
+ "synapse.http.server",
+ "synapse.storage._base",
+ "synapse.util.async",
+]
+
+
+def logcontext_tracer(frame, event, arg):
+ """A tracer that logs whenever a logcontext "unexpectedly" changes within
+ a function. Probably inaccurate.
+
+ Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
+ """
+ if event == 'call':
+ name = frame.f_globals["__name__"]
+ if name.startswith("synapse"):
+ if name == "synapse.util.logcontext":
+ if frame.f_code.co_name in ["__enter__", "__exit__"]:
+ tracer = frame.f_back.f_trace
+ if tracer:
+ tracer.just_changed = True
+
+ tracer = frame.f_trace
+ if tracer:
+ return tracer
+
+ if not any(name.startswith(ig) for ig in _to_ignore):
+ return LineTracer()
+
+
+class LineTracer(object):
+ __slots__ = ["context", "just_changed"]
+
+ def __init__(self):
+ self.context = LoggingContext.current_context()
+ self.just_changed = False
+
+ def __call__(self, frame, event, arg):
+ if event == 'line':
+ if self.just_changed:
+ self.context = LoggingContext.current_context()
+ self.just_changed = False
+ else:
+ c = LoggingContext.current_context()
+ if c != self.context:
+ logger.info(
+ "Context changed! %s -> %s, %s, %s",
+ self.context, c,
+ frame.f_code.co_filename, frame.f_lineno
+ )
+ self.context = c
+
+ return self
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index d5b1a37eff..3a83828d25 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -111,7 +111,7 @@ def time_function(f):
_log_debug_as_f(
f,
"[FUNC END] {%s-%d} %f",
- (func_name, id, end-start,),
+ (func_name, id, end - start,),
)
return r
@@ -168,3 +168,38 @@ def trace_function(f):
wrapped.__name__ = func_name
return wrapped
+
+
+def get_previous_frames():
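+ """Returns a string describing each synapse frame on the caller's stack,
+ including the arguments each frame was called with."""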
+ s = inspect.currentframe().f_back.f_back
+ to_return = []
+ while s:
+ if s.f_globals["__name__"].startswith("synapse"):
+ filename, lineno, function, _, _ = inspect.getframeinfo(s)
+ args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+ to_return.append("{{ %s:%d %s - Args: %s }}" % (
+ filename, lineno, function, args_string
+ ))
+
+ s = s.f_back
+
+ return ", ".join(to_return)
+
+
+def get_previous_frame(ignore=()):
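+ """Returns a string describing the first synapse frame on the caller's
+ stack whose module does not start with any of the prefixes in `ignore`,
+ or None if there is no such frame."""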
+ s = inspect.currentframe().f_back.f_back
+
+ while s:
+ if s.f_globals["__name__"].startswith("synapse"):
+ if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
+ filename, lineno, function, _, _ = inspect.getframeinfo(s)
+ args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+ return "{{ %s:%d %s - Args: %s }}" % (
+ filename, lineno, function, args_string
+ )
+
+ s = s.f_back
+
+ return None
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
new file mode 100644
index 0000000000..c51b641125
--- /dev/null
+++ b/synapse/util/metrics.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.util.logcontext import LoggingContext
+import synapse.metrics
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+block_timer = metrics.register_distribution(
+ "block_timer",
+ labels=["block_name"]
+)
+
+block_ru_utime = metrics.register_distribution(
+ "block_ru_utime", labels=["block_name"]
+)
+
+block_ru_stime = metrics.register_distribution(
+ "block_ru_stime", labels=["block_name"]
+)
+
+block_db_txn_count = metrics.register_distribution(
+ "block_db_txn_count", labels=["block_name"]
+)
+
+block_db_txn_duration = metrics.register_distribution(
+ "block_db_txn_duration", labels=["block_name"]
+)
+
+
+class Measure(object):
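+ """Context manager that reports the wall-clock duration and resource usage
+ of the enclosed block to the block_* distributions above.
+
+ Illustrative usage, where `clock` is a synapse.util.Clock and the block
+ name is arbitrary::
+
+ with Measure(clock, "frobnicate"): frobnicate()
+
+ Resource usage is only reported if a LoggingContext is active when the
+ block is entered.
+ """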
+ __slots__ = [
+ "clock", "name", "start_context", "start", "new_context", "ru_utime",
+ "ru_stime", "db_txn_count", "db_txn_duration"
+ ]
+
+ def __init__(self, clock, name):
+ self.clock = clock
+ self.name = name
+ self.start_context = None
+ self.start = None
+
+ def __enter__(self):
+ self.start = self.clock.time_msec()
+ self.start_context = LoggingContext.current_context()
+ if self.start_context:
+ self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
+ self.db_txn_count = self.start_context.db_txn_count
+ self.db_txn_duration = self.start_context.db_txn_duration
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if exc_type is not None or not self.start_context:
+ return
+
+ duration = self.clock.time_msec() - self.start
+ block_timer.inc_by(duration, self.name)
+
+ context = LoggingContext.current_context()
+
+ if context != self.start_context:
+ logger.warn(
+ "Context has unexpectedly changed from '%s' to '%s'. (%r)",
+ self.start_context, context, self.name
+ )
+ return
+
+ if not context:
+ logger.warn("Expected context. (%r)", self.name)
+ return
+
+ ru_utime, ru_stime = context.get_resource_usage()
+
+ block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
+ block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
+ block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
+ block_db_txn_duration.inc_by(
+ context.db_txn_duration - self.db_txn_duration, self.name
+ )
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index c37d6f12e3..4076eed269 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
+from synapse.util.logcontext import preserve_fn
import collections
import contextlib
@@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
- ret_defer = sleep(self.sleep_msec/1000.0)
+ ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
new file mode 100644
index 0000000000..7412fc57a4
--- /dev/null
+++ b/synapse/util/wheel_timer.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class _Entry(object):
+ __slots__ = ["end_key", "queue"]
+
+ def __init__(self, end_key):
+ self.end_key = end_key
+ self.queue = []
+
+
+class WheelTimer(object):
+ """Stores arbitrary objects that will be returned after their timers have
+ expired.
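+
+ Illustrative usage::
+
+ timer = WheelTimer(bucket_size=1000)
+ timer.insert(now=0, obj="foo", then=1500)
+ timer.fetch(1000) # returns []
+ timer.fetch(2000) # returns ["foo"]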
+ """
+
+ def __init__(self, bucket_size=5000):
+ """
+ Args:
+ bucket_size (int): Size of buckets in ms. Corresponds roughly to the
+ accuracy of the timer.
+ """
+ self.bucket_size = bucket_size
+ self.entries = []
+ self.current_tick = 0
+
+ def insert(self, now, obj, then):
+ """Inserts object into timer.
+
+ Args:
+ now (int): Current time in msec
+ obj (object): Object to be inserted
+ then (int): The time, in msec, strictly after which to return the object.
+ """
+ then_key = int(then / self.bucket_size) + 1
+
+ if self.entries:
+ min_key = self.entries[0].end_key
+ max_key = self.entries[-1].end_key
+
+ if then_key <= max_key:
+ # The max here is to protect against inserts for times in the past
+ self.entries[max(min_key, then_key) - min_key].queue.append(obj)
+ return
+
+ next_key = int(now / self.bucket_size) + 1
+ if self.entries:
+ last_key = self.entries[-1].end_key
+ else:
+ last_key = next_key
+
+ # Handle the case when `then` is in the past and `entries` is empty.
+ then_key = max(last_key, then_key)
+
+ # Add empty entries between the end of the current list and when we want
+ # to insert. This ensures there are no gaps.
+ self.entries.extend(
+ _Entry(key) for key in xrange(last_key, then_key + 1)
+ )
+
+ self.entries[-1].queue.append(obj)
+
+ def fetch(self, now):
+ """Fetch any objects that have timed out
+
+ Args:
+ now (int): Current time in msec
+
+ Returns:
+ list: List of objects that have timed out
+ """
+ now_key = int(now / self.bucket_size)
+
+ ret = []
+ while self.entries and self.entries[0].end_key <= now_key:
+ ret.extend(self.entries.pop(0).queue)
+
+ return ret
+
+ def __len__(self):
+ return sum(len(entry.queue) for entry in self.entries)