| author | Michael Telatynski <7t3chguy@gmail.com> | 2018-07-24 17:17:46 +0100 |
|---|---|---|
| committer | Michael Telatynski <7t3chguy@gmail.com> | 2018-07-24 17:17:46 +0100 |
| commit | 87951d3891efb5bccedf72c12b3da0d6ab482253 (patch) | |
| tree | de7d997567c66c5a4d8743c1f3b9d6b474f5cfd9 | /synapse/util |
| parent | if inviter_display_name == ""||None then default to inviter MXID (diff) | |
| parent | Merge pull request #3595 from matrix-org/erikj/use_deltas (diff) | |
| download | synapse-87951d3891efb5bccedf72c12b3da0d6ab482253.tar.xz | |
Merge branch 'develop' of github.com:matrix-org/synapse into t3chguy/default_inviter_display_name_3pid
Diffstat (limited to 'synapse/util')
28 files changed, 1295 insertions, 559 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 2a2360ab5d..680ea928c7 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,20 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError -from synapse.util.logcontext import PreserveLoggingContext - -from twisted.internet import defer, reactor, task - -import time import logging +from itertools import islice -logger = logging.getLogger(__name__) +import attr +from twisted.internet import defer, task -class DeferredTimedOutError(SynapseError): - def __init__(self): - super(DeferredTimedOutError, self).__init__(504, "Timed out") +from synapse.util.logcontext import PreserveLoggingContext + +logger = logging.getLogger(__name__) def unwrapFirstError(failure): @@ -35,16 +31,27 @@ def unwrapFirstError(failure): return failure.value.subFailure +@attr.s class Clock(object): - """A small utility that obtains current time-of-day so that time may be - mocked during unit-tests. + """ + A Clock wraps a Twisted reactor and provides utilities on top of it. - TODO(paul): Also move the sleep() functionality into it + Args: + reactor: The Twisted reactor to use. """ + _reactor = attr.ib() + + @defer.inlineCallbacks + def sleep(self, seconds): + d = defer.Deferred() + with PreserveLoggingContext(): + self._reactor.callLater(seconds, d.callback, seconds) + res = yield d + defer.returnValue(res) def time(self): """Returns the current system time in seconds since epoch.""" - return time.time() + return self._reactor.seconds() def time_msec(self): """Returns the current system time in miliseconds since epoch.""" @@ -59,9 +66,10 @@ class Clock(object): f(function): The function to call repeatedly. msec(float): How long to wait between calls in milliseconds. 
""" - l = task.LoopingCall(f) - l.start(msec / 1000.0, now=False) - return l + call = task.LoopingCall(f) + call.clock = self._reactor + call.start(msec / 1000.0, now=False) + return call def call_later(self, delay, callback, *args, **kwargs): """Call something later @@ -77,61 +85,27 @@ class Clock(object): callback(*args, **kwargs) with PreserveLoggingContext(): - return reactor.callLater(delay, wrapped_callback, *args, **kwargs) + return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) def cancel_call_later(self, timer, ignore_errs=False): try: timer.cancel() - except: + except Exception: if not ignore_errs: raise - def time_bound_deferred(self, given_deferred, time_out): - if given_deferred.called: - return given_deferred - - ret_deferred = defer.Deferred() - def timed_out_fn(): - e = DeferredTimedOutError() +def batch_iter(iterable, size): + """batch an iterable up into tuples with a maximum size - try: - ret_deferred.errback(e) - except: - pass + Args: + iterable (iterable): the iterable to slice + size (int): the maximum batch size - try: - given_deferred.cancel() - except: - pass - - timer = None - - def cancel(res): - try: - self.cancel_call_later(timer) - except: - pass - return res - - ret_deferred.addBoth(cancel) - - def success(res): - try: - ret_deferred.callback(res) - except: - pass - - return res - - def err(res): - try: - ret_deferred.errback(res) - except: - pass - - given_deferred.addCallbacks(callback=success, errback=err) - - timer = self.call_later(time_out, timed_out_fn) - - return ret_deferred + Returns: + an iterator over the chunks + """ + # make sure we can deal with iterables like lists too + sourceiter = iter(iterable) + # call islice until it returns an empty tuple + return iter(lambda: tuple(islice(sourceiter, size)), ()) diff --git a/synapse/util/async.py b/synapse/util/async.py index 1453faf0ef..a7094e2fb4 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,38 +13,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections +import logging +from contextlib import contextmanager + +from six.moves import range +from twisted.internet import defer +from twisted.internet.defer import CancelledError +from twisted.python import failure -from twisted.internet import defer, reactor +from synapse.util import Clock, logcontext, unwrapFirstError from .logcontext import ( - PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, ) -from synapse.util import unwrapFirstError - -from contextlib import contextmanager - -import logging logger = logging.getLogger(__name__) -@defer.inlineCallbacks -def sleep(seconds): - d = defer.Deferred() - with PreserveLoggingContext(): - reactor.callLater(seconds, d.callback, seconds) - res = yield d - defer.returnValue(res) - - -def run_on_reactor(): - """ This will cause the rest of the function to be invoked upon the next - iteration of the main loop - """ - return sleep(0) - - class ObservableDeferred(object): """Wraps a deferred object so that we can add observer deferreds. 
These observer deferreds do not affect the callback chain of the original @@ -53,6 +43,11 @@ class ObservableDeferred(object): Cancelling or otherwise resolving an observer will not affect the original ObservableDeferred. + + NB that it does not attempt to do anything with logcontexts; in general + you should probably make_deferred_yieldable the deferreds + returned by `observe`, and ensure that the original deferred runs its + callbacks in the sentinel logcontext. """ __slots__ = ["_deferred", "_observers", "_result"] @@ -68,7 +63,7 @@ class ObservableDeferred(object): try: # TODO: Handle errors here. self._observers.pop().callback(r) - except: + except Exception: pass return r @@ -78,7 +73,7 @@ class ObservableDeferred(object): try: # TODO: Handle errors here. self._observers.pop().errback(f) - except: + except Exception: pass if consumeErrors: @@ -151,77 +146,19 @@ def concurrently_execute(func, args, limit): def _concurrently_execute_inner(): try: while True: - yield func(it.next()) + yield func(next(it)) except StopIteration: pass - return preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(_concurrently_execute_inner)() - for _ in xrange(limit) + return logcontext.make_deferred_yieldable(defer.gatherResults([ + run_in_background(_concurrently_execute_inner) + for _ in range(limit) ], consumeErrors=True)).addErrback(unwrapFirstError) class Linearizer(object): - """Linearizes access to resources based on a key. Useful to ensure only one - thing is happening at a time on a given resource. - - Example: - - with (yield linearizer.queue("test_key")): - # do some work. - - """ - def __init__(self, name=None): - if name is None: - self.name = id(self) - else: - self.name = name - self.key_to_defer = {} - - @defer.inlineCallbacks - def queue(self, key): - # If there is already a deferred in the queue, we pull it out so that - # we can wait on it later. - # Then we replace it with a deferred that we resolve *after* the - # context manager has exited. - # We only return the context manager after the previous deferred has - # resolved. - # This all has the net effect of creating a chain of deferreds that - # wait for the previous deferred before starting their work. - current_defer = self.key_to_defer.get(key) - - new_defer = defer.Deferred() - self.key_to_defer[key] = new_defer - - if current_defer: - logger.info( - "Waiting to acquire linearizer lock %r for key %r", self.name, key - ) - try: - with PreserveLoggingContext(): - yield current_defer - except: - logger.exception("Unexpected exception in Linearizer") - - logger.info("Acquired linearizer lock %r for key %r", self.name, key) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - logger.info("Releasing linearizer lock %r for key %r", self.name, key) - new_defer.callback(None) - current_d = self.key_to_defer.get(key) - if current_d is new_defer: - self.key_to_defer.pop(key, None) - - defer.returnValue(_ctx_manager()) - - -class Limiter(object): """Limits concurrent access to resources based on a key. Useful to ensure - only a few thing happen at a time on a given resource. + only a few things happen at a time on a given resource. Example: @@ -229,22 +166,31 @@ class Limiter(object): # do some work. 
""" - def __init__(self, max_count): + def __init__(self, name=None, max_count=1, clock=None): """ Args: - max_count(int): The maximum number of concurrent access + max_count(int): The maximum number of concurrent accesses """ + if name is None: + self.name = id(self) + else: + self.name = name + + if not clock: + from twisted.internet import reactor + clock = Clock(reactor) + self._clock = clock self.max_count = max_count # key_to_defer is a map from the key to a 2 element list where - # the first element is the number of things executing - # the second element is a list of deferreds for the things blocked from - # executing. + # the first element is the number of things executing, and + # the second element is an OrderedDict, where the keys are deferreds for the + # things blocked from executing. self.key_to_defer = {} @defer.inlineCallbacks def queue(self, key): - entry = self.key_to_defer.setdefault(key, [0, []]) + entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) # If the number of things executing is greater than the maximum # then add a deferred to the list of blocked items @@ -252,27 +198,71 @@ class Limiter(object): # this item so that it can continue executing. if entry[0] >= self.max_count: new_defer = defer.Deferred() - entry[1].append(new_defer) - with PreserveLoggingContext(): - yield new_defer + entry[1][new_defer] = 1 + + logger.info( + "Waiting to acquire linearizer lock %r for key %r", self.name, key, + ) + try: + yield make_deferred_yieldable(new_defer) + except Exception as e: + if isinstance(e, CancelledError): + logger.info( + "Cancelling wait for linearizer lock %r for key %r", + self.name, key, + ) + else: + logger.warn( + "Unexpected exception waiting for linearizer lock %r for key %r", + self.name, key, + ) + + # we just have to take ourselves back out of the queue. + del entry[1][new_defer] + raise + + logger.info("Acquired linearizer lock %r for key %r", self.name, key) + entry[0] += 1 + + # if the code holding the lock completes synchronously, then it + # will recursively run the next claimant on the list. That can + # relatively rapidly lead to stack exhaustion. This is essentially + # the same problem as http://twistedmatrix.com/trac/ticket/9304. + # + # In order to break the cycle, we add a cheeky sleep(0) here to + # ensure that we fall back to the reactor between each iteration. + # + # (This needs to happen while we hold the lock, and the context manager's exit + # code must be synchronous, so this is the only sensible place.) + yield self._clock.sleep(0) - entry[0] += 1 + else: + logger.info( + "Acquired uncontended linearizer lock %r for key %r", self.name, key, + ) + entry[0] += 1 @contextmanager def _ctx_manager(): try: yield finally: + logger.info("Releasing linearizer lock %r for key %r", self.name, key) + # We've finished executing so check if there are any things # blocked waiting to execute and start one of them entry[0] -= 1 - try: - entry[1].pop(0).callback(None) - except IndexError: - # If nothing else is executing for this key then remove it - # from the map - if entry[0] == 0: - self.key_to_defer.pop(key, None) + + if entry[1]: + (next_def, _) = entry[1].popitem(last=False) + + # we need to run the next thing in the sentinel context. + with PreserveLoggingContext(): + next_def.callback(None) + elif entry[0] == 0: + # We were the last thing for this key: remove it from the + # map. 
+ del self.key_to_defer[key] defer.returnValue(_ctx_manager()) @@ -316,7 +306,7 @@ class ReadWriteLock(object): # We wait for the latest writer to finish writing. We can safely ignore # any existing readers... as they're readers. - yield curr_writer + yield make_deferred_yieldable(curr_writer) @contextmanager def _ctx_manager(): @@ -345,7 +335,7 @@ class ReadWriteLock(object): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield preserve_context_over_deferred(defer.gatherResults(to_wait_on)) + yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): @@ -357,3 +347,69 @@ class ReadWriteLock(object): self.key_to_current_writer.pop(key) defer.returnValue(_ctx_manager()) + + +class DeferredTimeoutError(Exception): + """ + This error is raised by default when a L{Deferred} times out. + """ + + +def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): + """ + Add a timeout to a deferred by scheduling it to be cancelled after + timeout seconds. + + This is essentially a backport of deferred.addTimeout, which was introduced + in twisted 16.5. + + If the deferred gets timed out, it errbacks with a DeferredTimeoutError, + unless a cancelable function was passed to its initialization or unless + a different on_timeout_cancel callable is provided. + + Args: + deferred (defer.Deferred): deferred to be timed out + timeout (Number): seconds to time out after + reactor (twisted.internet.reactor): the Twisted reactor to use + + on_timeout_cancel (callable): A callable which is called immediately + after the deferred times out, and not if this deferred is + otherwise cancelled before the timeout. + + It takes an arbitrary value, which is the value of the deferred at + that exact point in time (probably a CancelledError Failure), and + the timeout. + + The default callable (if none is provided) will translate a + CancelledError Failure into a DeferredTimeoutError. + """ + timed_out = [False] + + def time_it_out(): + timed_out[0] = True + deferred.cancel() + + delayed_call = reactor.callLater(timeout, time_it_out) + + def convert_cancelled(value): + if timed_out[0]: + to_call = on_timeout_cancel or _cancelled_to_timed_out_error + return to_call(value, timeout) + return value + + deferred.addBoth(convert_cancelled) + + def cancel_timeout(result): + # stop the pending call to cancel the deferred if it's been fired + if delayed_call.active(): + delayed_call.cancel() + return result + + deferred.addBoth(cancel_timeout) + + +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise DeferredTimeoutError(timeout, "Deferred") + return value diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 4adae96681..7b065b195e 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -13,28 +13,87 @@ # See the License for the specific language governing permissions and # limitations under the License. 
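`add_timeout_to_deferred`, added to `synapse/util/async.py` above, backports `Deferred.addTimeout` from Twisted 16.5. A hedged sketch of how a caller might use it — `make_request` is a hypothetical function returning a Deferred:

```python
from twisted.internet import defer, reactor

from synapse.util.async import DeferredTimeoutError, add_timeout_to_deferred

@defer.inlineCallbacks
def query_with_timeout(make_request):
    d = make_request()  # hypothetical call returning a Deferred
    add_timeout_to_deferred(d, 10, reactor)
    try:
        res = yield d
    except DeferredTimeoutError:
        res = None  # gave up after 10 seconds
    defer.returnValue(res)
```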
-import synapse.metrics import os +import six +from six.moves import intern + +from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily + CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) -metrics = synapse.metrics.get_metrics_for("synapse.util.caches") + +def get_cache_factor_for(cache_name): + env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper() + factor = os.environ.get(env_var) + if factor: + return float(factor) + + return CACHE_SIZE_FACTOR + caches_by_name = {} -# cache_counter = metrics.register_cache( -# "cache", -# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, -# labels=["name"], -# ) - - -def register_cache(name, cache): - caches_by_name[name] = cache - return metrics.register_cache( - "cache", - lambda: len(cache), - name, - ) +collectors_by_name = {} + +cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) +cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) +cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"]) +cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) + +response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) +response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) +response_cache_evicted = Gauge( + "synapse_util_caches_response_cache:evicted_size", "", ["name"] +) +response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) + + +def register_cache(cache_type, cache_name, cache): + + # Check if the metric is already registered. Unregister it, if so. + # This usually happens during tests, as at runtime these caches are + # effectively singletons. + metric_name = "cache_%s_%s" % (cache_type, cache_name) + if metric_name in collectors_by_name.keys(): + REGISTRY.unregister(collectors_by_name[metric_name]) + + class CacheMetric(object): + + hits = 0 + misses = 0 + evicted_size = 0 + + def inc_hits(self): + self.hits += 1 + + def inc_misses(self): + self.misses += 1 + + def inc_evictions(self, size=1): + self.evicted_size += size + + def describe(self): + return [] + + def collect(self): + if cache_type == "response_cache": + response_cache_size.labels(cache_name).set(len(cache)) + response_cache_hits.labels(cache_name).set(self.hits) + response_cache_evicted.labels(cache_name).set(self.evicted_size) + response_cache_total.labels(cache_name).set(self.hits + self.misses) + else: + cache_size.labels(cache_name).set(len(cache)) + cache_hits.labels(cache_name).set(self.hits) + cache_evicted.labels(cache_name).set(self.evicted_size) + cache_total.labels(cache_name).set(self.hits + self.misses) + + yield GaugeMetricFamily("__unused", "") + + metric = CacheMetric() + REGISTRY.register(metric) + caches_by_name[cache_name] = cache + collectors_by_name[metric_name] = metric + return metric KNOWN_KEYS = { @@ -66,7 +125,9 @@ def intern_string(string): return None try: - string = string.encode("ascii") + if six.PY2: + string = string.encode("ascii") + return intern(string) except UnicodeEncodeError: return string diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index af65bfe7b8..f8a07df6b8 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
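The rewritten `register_cache` above replaces the old `synapse.metrics` wrapper with a per-cache `prometheus_client` collector. A sketch of registering a plain dict as a cache (the dict stands in for a real LruCache):

```python
from synapse.util.caches import register_cache

my_cache = {}
my_metrics = register_cache("cache", "my_cache", my_cache)

my_metrics.inc_misses()
my_cache["key"] = "value"
my_metrics.inc_hits()
# On a Prometheus scrape, collect() exports len(my_cache) as
# synapse_util_caches_cache:size{name="my_cache"}, along with the
# hits, evicted_size and total gauges for the same label.
```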
@@ -12,25 +13,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools +import inspect import logging +import threading +from collections import namedtuple + +import six +from six import itervalues, string_types + +from twisted.internet import defer +from synapse.util import logcontext, unwrapFirstError from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError, logcontext -from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.stringutils import to_ascii from . import register_cache -from twisted.internet import defer -from collections import namedtuple - -import functools -import inspect -import threading - - logger = logging.getLogger(__name__) @@ -39,12 +41,11 @@ _CacheSentinel = object() class CacheEntry(object): __slots__ = [ - "deferred", "sequence", "callbacks", "invalidated" + "deferred", "callbacks", "invalidated" ] - def __init__(self, deferred, sequence, callbacks): + def __init__(self, deferred, callbacks): self.deferred = deferred - self.sequence = sequence self.callbacks = set(callbacks) self.invalidated = False @@ -62,7 +63,6 @@ class Cache(object): "max_entries", "name", "keylen", - "sequence", "thread", "metrics", "_pending_deferred_cache", @@ -75,13 +75,16 @@ class Cache(object): self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, + evicted_callback=self._on_evicted, ) self.name = name self.keylen = keylen - self.sequence = 0 self.thread = None - self.metrics = register_cache(name, self.cache) + self.metrics = register_cache("cache", name, self.cache) + + def _on_evicted(self, evicted_count): + self.metrics.inc_evictions(evicted_count) def check_thread(self): expected_thread = self.thread @@ -109,11 +112,10 @@ class Cache(object): callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _CacheSentinel) if val is not _CacheSentinel: - if val.sequence == self.sequence: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: @@ -133,12 +135,9 @@ class Cache(object): self.check_thread() entry = CacheEntry( deferred=value, - sequence=self.sequence, callbacks=callbacks, ) - entry.callbacks.update(callbacks) - existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: existing_entry.invalidate() @@ -146,13 +145,25 @@ class Cache(object): self._pending_deferred_cache[key] = entry def shuffle(result): - if self.sequence == entry.sequence: - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - self.cache.set(key, result, entry.callbacks) - else: - entry.invalidate() + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, result, entry.callbacks) else: + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. 
(We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. entry.invalidate() return result @@ -164,25 +175,29 @@ class Cache(object): def invalidate(self, key): self.check_thread() + self.cache.pop(key, None) - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. if entry: entry.invalidate() - self.cache.pop(key, None) - def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): raise TypeError( "The cache key must be a tuple not %r" % (type(key),) ) - self.sequence += 1 self.cache.del_multi(key) + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above entry_dict = self._pending_deferred_cache.pop(key, None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): @@ -190,8 +205,10 @@ class Cache(object): def invalidate_all(self): self.check_thread() - self.sequence += 1 self.cache.clear() + for entry in itervalues(self._pending_deferred_cache): + entry.invalidate() + self._pending_deferred_cache.clear() class _CacheDescriptorBase(object): @@ -294,7 +311,7 @@ class CacheDescriptor(_CacheDescriptorBase): orig, num_args=num_args, inlineCallbacks=inlineCallbacks, cache_context=cache_context) - max_entries = int(max_entries * CACHE_SIZE_FACTOR) + max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) self.max_entries = max_entries self.tree = tree @@ -376,9 +393,10 @@ class CacheDescriptor(_CacheDescriptorBase): ret.addErrback(onErr) - # If our cache_key is a string, try to convert to ascii to save - # a bit of space in large caches - if isinstance(cache_key, basestring): + # If our cache_key is a string on py2, try to convert to ascii + # to save a bit of space in large caches. Py3 does this + # internally automatically. + if six.PY2 and isinstance(cache_key, string_types): cache_key = to_ascii(cache_key) result_d = ObservableDeferred(ret, consumeErrors=True) @@ -549,7 +567,7 @@ class CacheListDescriptor(_CacheDescriptorBase): return results return logcontext.make_deferred_yieldable(defer.gatherResults( - cached_defers.values(), + list(cached_defers.values()), consumeErrors=True, ).addCallback(update_results_dict).addErrback( unwrapFirstError diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index d4105822b3..6c0b5a4094 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
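`CacheDescriptor` above now sizes each cache with `get_cache_factor_for(orig.__name__)` rather than the global `CACHE_SIZE_FACTOR`, so individual caches can be scaled with environment variables. An illustrative sketch (the cache names are hypothetical, and `SYNAPSE_CACHE_FACTOR` is assumed to be left at its 0.5 default):

```python
import os

# Must be set before the factor is queried.
os.environ["SYNAPSE_CACHE_FACTOR_GET_USERS_IN_ROOM"] = "2.0"

from synapse.util.caches import get_cache_factor_for

print(get_cache_factor_for("get_users_in_room"))  # 2.0
print(get_cache_factor_for("some_other_cache"))   # 0.5 (global default)
```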
-from synapse.util.caches.lrucache import LruCache -from collections import namedtuple -from . import register_cache -import threading import logging +import threading +from collections import namedtuple +from synapse.util.caches.lrucache import LruCache + +from . import register_cache logger = logging.getLogger(__name__) @@ -55,7 +56,7 @@ class DictionaryCache(object): __slots__ = [] self.sentinel = Sentinel() - self.metrics = register_cache(name, self.cache) + self.metrics = register_cache("dictionary", name, self.cache) def check_thread(self): expected_thread = self.thread @@ -107,34 +108,37 @@ class DictionaryCache(object): self.sequence += 1 self.cache.clear() - def update(self, sequence, key, value, full=False, known_absent=None): + def update(self, sequence, key, value, fetched_keys=None): """Updates the entry in the cache Args: sequence - key - value (dict): The value to update the cache with. - full (bool): Whether the given value is the full dict, or just a - partial subset there of. If not full then any existing entries - for the key will be updated. - known_absent (set): Set of keys that we know don't exist in the full - dict. + key (K) + value (dict[X,Y]): The value to update the cache with. + fetched_keys (None|set[X]): All of the dictionary keys which were + fetched from the database. + + If None, this is the complete value for key K. Otherwise, it + is used to infer a list of keys which we know don't exist in + the full dict. """ self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - if known_absent is None: - known_absent = set() - if full: - self._insert(key, value, known_absent) + if fetched_keys is None: + self._insert(key, value, set()) else: - self._update_or_insert(key, value, known_absent) + self._update_or_insert(key, value, fetched_keys) def _update_or_insert(self, key, value, known_absent): - entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {})) + # We pop and reinsert as we need to tell the cache the size may have + # changed + + entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) entry.value.update(value) entry.known_absent.update(known_absent) + self.cache[key] = entry def _insert(self, key, value, known_absent): self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 6ad53a6390..465adc54a8 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
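The `update` rework above folds the old `full`/`known_absent` arguments into a single `fetched_keys` set. A sketch of the two call patterns — the constructor arguments are illustrative, since they are not shown in this diff:

```python
from synapse.util.caches.dictionary_cache import DictionaryCache

cache = DictionaryCache("state_cache")
seq = cache.sequence

# Complete entry: fetched_keys omitted, so this is the full dict for the key.
cache.update(seq, "!room:example.com", {"a": 1, "b": 2})

# Partial entry: only keys "a" and "c" were fetched from the database,
# so the cache can treat "c" as known-absent while leaving every other
# key as simply not-yet-fetched.
cache.update(seq, "!room:example.com", {"a": 1}, fetched_keys={"a", "c"})
```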
-from synapse.util.caches import register_cache - -from collections import OrderedDict import logging +from collections import OrderedDict +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.caches import register_cache logger = logging.getLogger(__name__) @@ -52,19 +52,22 @@ class ExpiringCache(object): self._cache = OrderedDict() - self.metrics = register_cache(cache_name, self) - self.iterable = iterable self._size_estimate = 0 + self.metrics = register_cache("expiring", cache_name, self) + def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire return def f(): - self._prune_cache() + run_as_background_process( + "prune_cache_%s" % self._cache_name, + self._prune_cache, + ) self._clock.looping_call(f, self._expiry_ms / 2) @@ -79,7 +82,11 @@ class ExpiringCache(object): while self._max_len and len(self) > self._max_len: _key, value = self._cache.popitem(last=False) if self.iterable: - self._size_estimate -= len(value.value) + removed_len = len(value.value) + self.metrics.inc_evictions(removed_len) + self._size_estimate -= removed_len + else: + self.metrics.inc_evictions() def __getitem__(self, key): try: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index cf5fbb679c..b684f24e7b 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -14,8 +14,8 @@ # limitations under the License. -from functools import wraps import threading +from functools import wraps from synapse.util.caches.treecache import TreeCache @@ -49,7 +49,24 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None, + evicted_callback=None): + """ + Args: + max_size (int): + + keylen (int): + + cache_type (type): + type of underlying cache to be used. Typically one of dict + or TreeCache. + + size_callback (func(V) -> int | None): + + evicted_callback (func(int)|None): + if not None, called on eviction with the size of the evicted + entry + """ cache = cache_type() self.cache = cache # Used for introspection. list_root = _Node(None, None, None, None) @@ -61,8 +78,10 @@ class LruCache(object): def evict(): while cache_len() > max_size: todelete = list_root.prev_node - delete_node(todelete) + evicted_len = delete_node(todelete) cache.pop(todelete.key, None) + if evicted_callback: + evicted_callback(evicted_len) def synchronized(f): @wraps(f) @@ -111,12 +130,15 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + deleted_len = 1 if size_callback: - cached_cache_len[0] -= size_callback(node.value) + deleted_len = size_callback(node.value) + cached_cache_len[0] -= deleted_len for cb in node.callbacks: cb() node.callbacks.clear() + return deleted_len @synchronized def cache_get(key, default=None, callbacks=[]): @@ -132,14 +154,21 @@ class LruCache(object): def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: - if value != node.value: + # We sometimes store large objects, e.g. dicts, which cause + # the inequality check to take a long time. So let's only do + # the check if we have some callbacks to call. 
+ if node.callbacks and value != node.value: for cb in node.callbacks: cb() node.callbacks.clear() - if size_callback: - cached_cache_len[0] -= size_callback(node.value) - cached_cache_len[0] += size_callback(value) + # We don't bother to protect this by value != node.value as + # generally size_callback will be cheap compared with equality + # checks. (For example, taking the size of two dicts is quicker + # than comparing them for equality.) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) node.callbacks.update(callbacks) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 00af539880..a8491b42d5 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -12,8 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from twisted.internet import defer from synapse.util.async import ObservableDeferred +from synapse.util.caches import register_cache +from synapse.util.logcontext import make_deferred_yieldable, run_in_background + +logger = logging.getLogger(__name__) class ResponseCache(object): @@ -24,20 +31,69 @@ class ResponseCache(object): used rather than trying to compute a new response. """ - def __init__(self, hs, timeout_ms=0): + def __init__(self, hs, name, timeout_ms=0): self.pending_result_cache = {} # Requests that haven't finished yet. self.clock = hs.get_clock() self.timeout_sec = timeout_ms / 1000. + self._name = name + self._metrics = register_cache( + "response_cache", name, self + ) + + def size(self): + return len(self.pending_result_cache) + + def __len__(self): + return self.size() + def get(self, key): + """Look up the given key. + + Can return either a new Deferred (which also doesn't follow the synapse + logcontext rules), or, if the request has completed, the actual + result. You will probably want to make_deferred_yieldable the result. + + If there is no entry for the key, returns None. It is worth noting that + this means there is no way to distinguish a completed result of None + from an absent cache entry. + + Args: + key (hashable): + + Returns: + twisted.internet.defer.Deferred|None|E: None if there is no entry + for this key; otherwise either a deferred result or the result + itself. + """ result = self.pending_result_cache.get(key) if result is not None: + self._metrics.inc_hits() return result.observe() else: + self._metrics.inc_misses() return None def set(self, key, deferred): + """Set the entry for the given key to the given deferred. + + *deferred* should run its callbacks in the sentinel logcontext (ie, + you should wrap normal synapse deferreds with + logcontext.run_in_background). + + Can return either a new Deferred (which also doesn't follow the synapse + logcontext rules), or, if *deferred* was already complete, the actual + result. You will probably want to make_deferred_yieldable the result. + + Args: + key (hashable): + deferred (twisted.internet.defer.Deferred[T): + + Returns: + twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual + result. 
+ """ result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result @@ -53,3 +109,52 @@ class ResponseCache(object): result.addBoth(remove) return result.observe() + + def wrap(self, key, callback, *args, **kwargs): + """Wrap together a *get* and *set* call, taking care of logcontexts + + First looks up the key in the cache, and if it is present makes it + follow the synapse logcontext rules and returns it. + + Otherwise, makes a call to *callback(*args, **kwargs)*, which should + follow the synapse logcontext rules, and adds the result to the cache. + + Example usage: + + @defer.inlineCallbacks + def handle_request(request): + # etc + defer.returnValue(result) + + result = yield response_cache.wrap( + key, + handle_request, + request, + ) + + Args: + key (hashable): key to get/set in the cache + + callback (callable): function to call if the key is not found in + the cache + + *args: positional parameters to pass to the callback, if it is used + + **kwargs: named paramters to pass to the callback, if it is used + + Returns: + twisted.internet.defer.Deferred: yieldable result + """ + result = self.get(key) + if not result: + logger.info("[%s]: no cached result for [%s], calculating new one", + self._name, key) + d = run_in_background(callback, *args, **kwargs) + result = self.set(key, d) + elif not isinstance(result, defer.Deferred) or result.called: + logger.info("[%s]: using completed cached result for [%s]", + self._name, key) + else: + logger.info("[%s]: using incomplete cached result for [%s]", + self._name, key) + return make_deferred_yieldable(result) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 941d873ab8..f2bde74dc5 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR - - -from blist import sorteddict import logging +from sortedcontainers import SortedDict + +from synapse.util import caches logger = logging.getLogger(__name__) @@ -32,16 +31,18 @@ class StreamChangeCache(object): entities that may have changed since that position. If position key is too old then the cache will simply return all given entities. """ - def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}): - self._max_size = int(max_size * CACHE_SIZE_FACTOR) + + def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None): + self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) self._entity_to_key = {} - self._cache = sorteddict() + self._cache = SortedDict() self._earliest_known_stream_pos = current_stream_pos self.name = name - self.metrics = register_cache(self.name, self._cache) + self.metrics = caches.register_cache("cache", self.name, self._cache) - for entity, stream_pos in prefilled_cache.items(): - self.entity_has_changed(entity, stream_pos) + if prefilled_cache: + for entity, stream_pos in prefilled_cache.items(): + self.entity_has_changed(entity, stream_pos) def has_entity_changed(self, entity, stream_pos): """Returns True if the entity may have been updated since stream_pos @@ -65,22 +66,25 @@ class StreamChangeCache(object): return False def get_entities_changed(self, entities, stream_pos): - """Returns subset of entities that have had new things since the - given position. 
If the position is too old it will just return the given list. + """ + Returns subset of entities that have had new things since the given + position. Entities unknown to the cache will be returned. If the + position is too old it will just return the given list. """ assert type(stream_pos) is int if stream_pos >= self._earliest_known_stream_pos: - keys = self._cache.keys() - i = keys.bisect_right(stream_pos) + changed_entities = { + self._cache[k] for k in self._cache.islice( + start=self._cache.bisect_right(stream_pos), + ) + } - result = set( - self._cache[k] for k in keys[i:] - ).intersection(entities) + result = changed_entities.intersection(entities) self.metrics.inc_hits() else: - result = entities + result = set(entities) self.metrics.inc_misses() return result @@ -90,12 +94,13 @@ class StreamChangeCache(object): """ assert type(stream_pos) is int + if not self._cache: + # If we have no cache, nothing can have changed. + return False + if stream_pos >= self._earliest_known_stream_pos: self.metrics.inc_hits() - keys = self._cache.keys() - i = keys.bisect_right(stream_pos) - - return i < len(keys) + return self._cache.bisect_right(stream_pos) < len(self._cache) else: self.metrics.inc_misses() return True @@ -107,10 +112,8 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos >= self._earliest_known_stream_pos: - keys = self._cache.keys() - i = keys.bisect_right(stream_pos) - - return [self._cache[k] for k in keys[i:]] + return [self._cache[k] for k in self._cache.islice( + start=self._cache.bisect_right(stream_pos))] else: return None @@ -129,8 +132,10 @@ class StreamChangeCache(object): self._entity_to_key[entity] = stream_pos while len(self._cache) > self._max_size: - k, r = self._cache.popitem() - self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max( + k, self._earliest_known_stream_pos, + ) self._entity_to_key.pop(r, None) def get_max_pos_of_last_change(self, entity): diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index fcc341a6b7..dd4c9e6067 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -1,3 +1,5 @@ +from six import itervalues + SENTINEL = object() @@ -49,7 +51,7 @@ class TreeCache(object): if popped is SENTINEL: return default - node_and_keys = zip(nodes, key) + node_and_keys = list(zip(nodes, key)) node_and_keys.reverse() node_and_keys.append((self.root, None)) @@ -76,7 +78,7 @@ def iterate_tree_cache_entry(d): can contain dicts. """ if isinstance(d, dict): - for value_d in d.itervalues(): + for value_d in itervalues(d): for value in iterate_tree_cache_entry(value_d): yield value else: diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index e68f94ce77..194da87639 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -13,32 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
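The move from blist's `sorteddict` to `sortedcontainers.SortedDict` above leans on `bisect_right` plus `islice` to visit only entries newer than a stream position, without copying the key list. The core pattern in isolation:

```python
from sortedcontainers import SortedDict

cache = SortedDict()  # stream position -> entity
cache[3] = "@alice:example.com"
cache[5] = "@bob:example.com"
cache[8] = "@carol:example.com"

stream_pos = 4
# Iterate keys strictly greater than stream_pos.
changed = {
    cache[k] for k in cache.islice(start=cache.bisect_right(stream_pos))
}
# changed == {"@bob:example.com", "@carol:example.com"}
```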
-from twisted.internet import defer - -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_fn -) - -from synapse.util import unwrapFirstError - import logging +from twisted.internet import defer + +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) + distributor.fire("user_left_room", user=user, room_id=room_id) def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) + distributor.fire("user_joined_room", user=user, room_id=room_id) class Distributor(object): @@ -52,9 +42,7 @@ class Distributor(object): model will do for today. """ - def __init__(self, suppress_failures=True): - self.suppress_failures = suppress_failures - + def __init__(self): self.signals = {} self.pre_registration = {} @@ -64,7 +52,6 @@ class Distributor(object): self.signals[name] = Signal( name, - suppress_failures=self.suppress_failures, ) if name in self.pre_registration: @@ -83,10 +70,18 @@ class Distributor(object): self.pre_registration[name].append(observer) def fire(self, name, *args, **kwargs): + """Dispatches the given signal to the registered observers. + + Runs the observers as a background process. Does not return a deferred. + """ if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - return self.signals[name].fire(*args, **kwargs) + run_as_background_process( + name, + self.signals[name].fire, + *args, **kwargs + ) class Signal(object): @@ -99,9 +94,8 @@ class Signal(object): method into all of the observers. """ - def __init__(self, name, suppress_failures): + def __init__(self, name): self.name = name - self.suppress_failures = suppress_failures self.observers = [] def observe(self, observer): @@ -111,7 +105,6 @@ class Signal(object): Each observer callable may return a Deferred.""" self.observers.append(observer) - @defer.inlineCallbacks def fire(self, *args, **kwargs): """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is @@ -129,22 +122,17 @@ class Signal(object): failure.type, failure.value, failure.getTracebackObject())) - if not self.suppress_failures: - return failure return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) - with PreserveLoggingContext(): - deferreds = [ - do(observer) - for observer in self.observers - ] - - res = yield defer.gatherResults( - deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) + deferreds = [ + run_in_background(do, o) + for o in self.observers + ] - defer.returnValue(res) + return make_deferred_yieldable(defer.gatherResults( + deferreds, consumeErrors=True, + )) def __repr__(self): return "<Signal name=%r>" % (self.name,) diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py new file mode 100644 index 0000000000..629ed44149 --- /dev/null +++ b/synapse/util/file_consumer.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
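`Distributor.fire` above is now fire-and-forget: it dispatches observers via `run_as_background_process` and no longer returns a deferred. A usage sketch with an illustrative signal and observer:

```python
from synapse.util.distributor import Distributor

distributor = Distributor()
distributor.declare("user_joined_room")

def on_join(user, room_id):
    print("%s joined %s" % (user, room_id))

distributor.observe("user_joined_room", on_join)

# Observers run in the background; exceptions they raise are logged
# but do not propagate to the caller.
distributor.fire("user_joined_room",
                 user="@alice:example.com", room_id="!room:example.com")
```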
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from six.moves import queue + +from twisted.internet import threads + +from synapse.util.logcontext import make_deferred_yieldable, run_in_background + + +class BackgroundFileConsumer(object): + """A consumer that writes to a file like object. Supports both push + and pull producers + + Args: + file_obj (file): The file like object to write to. Closed when + finished. + reactor (twisted.internet.reactor): the Twisted reactor to use + """ + + # For PushProducers pause if we have this many unwritten slices + _PAUSE_ON_QUEUE_SIZE = 5 + # And resume once the size of the queue is less than this + _RESUME_ON_QUEUE_SIZE = 2 + + def __init__(self, file_obj, reactor): + self._file_obj = file_obj + + self._reactor = reactor + + # Producer we're registered with + self._producer = None + + # True if PushProducer, false if PullProducer + self.streaming = False + + # For PushProducers, indicates whether we've paused the producer and + # need to call resumeProducing before we get more data. + self._paused_producer = False + + # Queue of slices of bytes to be written. When producer calls + # unregister a final None is sent. + self._bytes_queue = queue.Queue() + + # Deferred that is resolved when finished writing + self._finished_deferred = None + + # If the _writer thread throws an exception it gets stored here. + self._write_exception = None + + def registerProducer(self, producer, streaming): + """Part of IConsumer interface + + Args: + producer (IProducer) + streaming (bool): True if push based producer, False if pull + based. + """ + if self._producer: + raise Exception("registerProducer called twice") + + self._producer = producer + self.streaming = streaming + self._finished_deferred = run_in_background( + threads.deferToThreadPool, + self._reactor, + self._reactor.getThreadPool(), + self._writer, + ) + if not streaming: + self._producer.resumeProducing() + + def unregisterProducer(self): + """Part of IProducer interface + """ + self._producer = None + if not self._finished_deferred.called: + self._bytes_queue.put_nowait(None) + + def write(self, bytes): + """Part of IProducer interface + """ + if self._write_exception: + raise self._write_exception + + if self._finished_deferred.called: + raise Exception("consumer has closed") + + self._bytes_queue.put_nowait(bytes) + + # If this is a PushProducer and the queue is getting behind + # then we pause the producer. + if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: + self._paused_producer = True + self._producer.pauseProducing() + + def _writer(self): + """This is run in a background thread to write to the file. + """ + try: + while self._producer or not self._bytes_queue.empty(): + # If we've paused the producer check if we should resume the + # producer. + if self._producer and self._paused_producer: + if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: + self._reactor.callFromThread(self._resume_paused_producer) + + bytes = self._bytes_queue.get() + + # If we get a None (or empty list) then that's a signal used + # to indicate we should check if we should stop. 
+ if bytes: + self._file_obj.write(bytes) + + # If its a pull producer then we need to explicitly ask for + # more stuff. + if not self.streaming and self._producer: + self._reactor.callFromThread(self._producer.resumeProducing) + except Exception as e: + self._write_exception = e + raise + finally: + self._file_obj.close() + + def wait(self): + """Returns a deferred that resolves when finished writing to file + """ + return make_deferred_yieldable(self._finished_deferred) + + def _resume_paused_producer(self): + """Gets called if we should resume producing after being paused + """ + if self._paused_producer and self._producer: + self._paused_producer = False + self._producer.resumeProducing() diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 6322f0f55c..581c6052ac 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,18 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import string_types + +from canonicaljson import json from frozendict import frozendict def freeze(o): - t = type(o) - if t is dict: + if isinstance(o, dict): return frozendict({k: freeze(v) for k, v in o.items()}) - if t is frozendict: + if isinstance(o, frozendict): return o - if t is str or t is unicode: + if isinstance(o, string_types): return o try: @@ -36,11 +38,10 @@ def freeze(o): def unfreeze(o): - t = type(o) - if t is dict or t is frozendict: + if isinstance(o, (dict, frozendict)): return dict({k: unfreeze(v) for k, v in o.items()}) - if t is str or t is unicode: + if isinstance(o, string_types): return o try: @@ -49,3 +50,21 @@ def unfreeze(o): pass return o + + +def _handle_frozendict(obj): + """Helper for EventEncoder. Makes frozendicts serializable by returning + the underlying dict + """ + if type(obj) is frozendict: + # fishing the protected dict out of the object is a bit nasty, + # but we don't really want the overhead of copying the dict. + return obj._dict + raise TypeError('Object of type %s is not JSON serializable' % + obj.__class__.__name__) + + +# A JSONEncoder which is capable of encoding frozendics without barfing +frozendict_json_encoder = json.JSONEncoder( + default=_handle_frozendict, +) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 45be47159a..2d7ddc1cbe 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.web.resource import Resource - import logging +from twisted.web.resource import NoResource + logger = logging.getLogger(__name__) @@ -40,12 +40,15 @@ def create_resource_tree(desired_tree, root_resource): # extra resources to existing nodes. See self._resource_id for the key. 
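`frozendict_json_encoder`, added to `frozenutils` above, lets frozen event content be serialised without first deep-copying it back into plain dicts. A quick sketch:

```python
from synapse.util.frozenutils import freeze, frozendict_json_encoder

content = freeze({"body": "hello", "msgtype": "m.text"})

# json.dumps(content) would raise TypeError, since frozendict is not a
# dict subclass; the custom encoder serialises the underlying dict.
print(frozendict_json_encoder.encode(content))
# {"body": "hello", "msgtype": "m.text"}
```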
resource_mappings = {} for full_path, res in desired_tree.items(): + # twisted requires all resources to be bytes + full_path = full_path.encode("utf-8") + logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource - for path_seg in full_path.split('/')[1:-1]: + for path_seg in full_path.split(b'/')[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource = Resource() + child_resource = NoResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource @@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource): # =========================== # now attach the actual desired resource - last_path_seg = full_path.split('/')[-1] + last_path_seg = full_path.split(b'/')[-1] # if there is already a resource here, thieve its children and # replace it diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 990216145e..8dcae50b39 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -22,10 +22,10 @@ them. See doc/log_contexts.rst for details on how this works. """ -from twisted.internet import defer - -import threading import logging +import threading + +from twisted.internet import defer logger = logging.getLogger(__name__) @@ -42,23 +42,128 @@ try: def get_thread_resource_usage(): return resource.getrusage(RUSAGE_THREAD) -except: +except Exception: # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we # won't track resource usage by returning None. def get_thread_resource_usage(): return None +class ContextResourceUsage(object): + """Object for tracking the resources used by a log context + + Attributes: + ru_utime (float): user CPU time (in seconds) + ru_stime (float): system CPU time (in seconds) + db_txn_count (int): number of database transactions done + db_sched_duration_sec (float): amount of time spent waiting for a + database connection + db_txn_duration_sec (float): amount of time spent doing database + transactions (excluding scheduling time) + evt_db_fetch_count (int): number of events requested from the database + """ + + __slots__ = [ + "ru_stime", "ru_utime", + "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec", + "evt_db_fetch_count", + ] + + def __init__(self, copy_from=None): + """Create a new ContextResourceUsage + + Args: + copy_from (ContextResourceUsage|None): if not None, an object to + copy stats from + """ + if copy_from is None: + self.reset() + else: + self.ru_utime = copy_from.ru_utime + self.ru_stime = copy_from.ru_stime + self.db_txn_count = copy_from.db_txn_count + + self.db_txn_duration_sec = copy_from.db_txn_duration_sec + self.db_sched_duration_sec = copy_from.db_sched_duration_sec + self.evt_db_fetch_count = copy_from.evt_db_fetch_count + + def copy(self): + return ContextResourceUsage(copy_from=self) + + def reset(self): + self.ru_stime = 0. + self.ru_utime = 0. + self.db_txn_count = 0 + + self.db_txn_duration_sec = 0 + self.db_sched_duration_sec = 0 + self.evt_db_fetch_count = 0 + + def __repr__(self): + return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', " + "db_txn_count='%r', db_txn_duration_sec='%r', " + "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % ( + self.ru_stime, + self.ru_utime, + self.db_txn_count, + self.db_txn_duration_sec, + self.db_sched_duration_sec, + self.evt_db_fetch_count,) + + def __iadd__(self, other): + """Add another ContextResourceUsage's stats to this one's. 
+ + Args: + other (ContextResourceUsage): the other resource usage object + """ + self.ru_utime += other.ru_utime + self.ru_stime += other.ru_stime + self.db_txn_count += other.db_txn_count + self.db_txn_duration_sec += other.db_txn_duration_sec + self.db_sched_duration_sec += other.db_sched_duration_sec + self.evt_db_fetch_count += other.evt_db_fetch_count + return self + + def __isub__(self, other): + self.ru_utime -= other.ru_utime + self.ru_stime -= other.ru_stime + self.db_txn_count -= other.db_txn_count + self.db_txn_duration_sec -= other.db_txn_duration_sec + self.db_sched_duration_sec -= other.db_sched_duration_sec + self.evt_db_fetch_count -= other.evt_db_fetch_count + return self + + def __add__(self, other): + res = ContextResourceUsage(copy_from=self) + res += other + return res + + def __sub__(self, other): + res = ContextResourceUsage(copy_from=self) + res -= other + return res + + class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a "with" block. + + If a parent is given when creating a new context, then: + - logging fields are copied from the parent to the new context on entry + - when the new context exits, the cpu usage stats are copied from the + child to the parent + Args: name (str): Name for the context for debugging. + parent_context (LoggingContext|None): The parent of the new context """ __slots__ = [ - "previous_context", "name", "usage_start", "usage_end", "main_thread", - "__dict__", "tag", "alive", + "previous_context", "name", "parent_context", + "_resource_usage", + "usage_start", + "main_thread", "alive", + "request", "tag", ] thread_local = threading.local() @@ -80,32 +185,49 @@ class LoggingContext(object): def stop(self): pass - def add_database_transaction(self, duration_ms): + def add_database_transaction(self, duration_sec): + pass + + def add_database_scheduled(self, sched_sec): + pass + + def record_event_fetch(self, event_count): pass def __nonzero__(self): return False + __bool__ = __nonzero__ # python3 sentinel = Sentinel() - def __init__(self, name=None): + def __init__(self, name=None, parent_context=None): self.previous_context = LoggingContext.current_context() self.name = name - self.ru_stime = 0. - self.ru_utime = 0. - self.db_txn_count = 0 - self.db_txn_duration = 0. + + # track the resources used by this context so far + self._resource_usage = ContextResourceUsage() + + # If alive has the thread resource usage when the logcontext last + # became active. self.usage_start = None + self.main_thread = threading.current_thread() + self.request = None self.tag = "" self.alive = True + self.parent_context = parent_context + def __str__(self): return "%s@%x" % (self.name, id(self)) @classmethod def current_context(cls): - """Get the current logging context from thread local storage""" + """Get the current logging context from thread local storage + + Returns: + LoggingContext: the current logging context + """ return getattr(cls.thread_local, "current_context", cls.sentinel) @classmethod @@ -133,18 +255,22 @@ class LoggingContext(object): self.previous_context, old_context ) self.alive = True + + if self.parent_context is not None: + self.parent_context.copy_to(self) + return self def __exit__(self, type, value, traceback): """Restore the logging context in thread local storage to the state it was before this context was entered. Returns: - None to avoid suppressing any exeptions that were thrown. + None to avoid suppressing any exceptions that were thrown. 
""" current = self.set_current_context(self.previous_context) if current is not self: if current is self.sentinel: - logger.debug("Expected logging context %s has been lost", self) + logger.warn("Expected logging context %s has been lost", self) else: logger.warn( "Current logging context %s is not expected context %s", @@ -154,47 +280,91 @@ class LoggingContext(object): self.previous_context = None self.alive = False + # if we have a parent, pass our CPU usage stats on + if self.parent_context is not None: + self.parent_context._resource_usage += self._resource_usage + + # reset them in case we get entered again + self._resource_usage.reset() + def copy_to(self, record): - """Copy fields from this context to the record""" - for key, value in self.__dict__.items(): - setattr(record, key, value) + """Copy logging fields from this context to a log record or + another LoggingContext + """ - record.ru_utime, record.ru_stime = self.get_resource_usage() + # 'request' is the only field we currently use in the logger, so that's + # all we need to copy + record.request = self.request def start(self): if threading.current_thread() is not self.main_thread: + logger.warning("Started logcontext %s on different thread", self) return - if self.usage_start and self.usage_end: - self.ru_utime += self.usage_end.ru_utime - self.usage_start.ru_utime - self.ru_stime += self.usage_end.ru_stime - self.usage_start.ru_stime - self.usage_start = None - self.usage_end = None - + # If we haven't already started record the thread resource usage so + # far if not self.usage_start: self.usage_start = get_thread_resource_usage() def stop(self): if threading.current_thread() is not self.main_thread: + logger.warning("Stopped logcontext %s on different thread", self) + return + + # When we stop, let's record the cpu used since we started + if not self.usage_start: + logger.warning( + "Called stop on logcontext %s without calling start", self, + ) return - if self.usage_start: - self.usage_end = get_thread_resource_usage() + usage_end = get_thread_resource_usage() + + self._resource_usage.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime + self._resource_usage.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime + + self.usage_start = None def get_resource_usage(self): - ru_utime = self.ru_utime - ru_stime = self.ru_stime + """Get resources used by this logcontext so far. - if self.usage_start and threading.current_thread() is self.main_thread: + Returns: + ContextResourceUsage: a *copy* of the object tracking resource + usage so far + """ + # we always return a copy, for consistency + res = self._resource_usage.copy() + + # If we are on the correct thread and we're currently running then we + # can include resource usage so far. + is_main_thread = threading.current_thread() is self.main_thread + if self.alive and self.usage_start and is_main_thread: current = get_thread_resource_usage() - ru_utime += current.ru_utime - self.usage_start.ru_utime - ru_stime += current.ru_stime - self.usage_start.ru_stime + res.ru_utime += current.ru_utime - self.usage_start.ru_utime + res.ru_stime += current.ru_stime - self.usage_start.ru_stime - return ru_utime, ru_stime + return res - def add_database_transaction(self, duration_ms): - self.db_txn_count += 1 - self.db_txn_duration += duration_ms / 1000. 
 
 
 class LoggingContextFilter(logging.Filter):
@@ -248,7 +418,7 @@ class PreserveLoggingContext(object):
         context = LoggingContext.set_current_context(self.current_context)
 
         if context != self.new_context:
-            logger.debug(
+            logger.warn(
                 "Unexpected logging context: %s is not %s",
                 context, self.new_context,
             )
@@ -261,105 +431,62 @@ class PreserveLoggingContext(object):
             )
 
 
-class _PreservingContextDeferred(defer.Deferred):
-    """A deferred that ensures that all callbacks and errbacks are called with
-    the given logging context.
-    """
-    def __init__(self, context):
-        self._log_context = context
-        defer.Deferred.__init__(self)
-
-    def addCallbacks(self, callback, errback=None,
-                     callbackArgs=None, callbackKeywords=None,
-                     errbackArgs=None, errbackKeywords=None):
-        callback = self._wrap_callback(callback)
-        errback = self._wrap_callback(errback)
-        return defer.Deferred.addCallbacks(
-            self, callback,
-            errback=errback,
-            callbackArgs=callbackArgs,
-            callbackKeywords=callbackKeywords,
-            errbackArgs=errbackArgs,
-            errbackKeywords=errbackKeywords,
-        )
+def preserve_fn(f):
+    """Function decorator which wraps the function with run_in_background"""
+    def g(*args, **kwargs):
+        return run_in_background(f, *args, **kwargs)
+    return g
 
-    def _wrap_callback(self, f):
-        def g(res, *args, **kwargs):
-            with PreserveLoggingContext(self._log_context):
-                res = f(res, *args, **kwargs)
-            return res
-        return g
 
+def run_in_background(f, *args, **kwargs):
+    """Calls a function, ensuring that the current context is restored after
+    return from the function, and that the sentinel context is set once the
+    deferred returned by the function completes.
 
-def preserve_context_over_fn(fn, *args, **kwargs):
-    """Takes a function and invokes it with the given arguments, but removes
-    and restores the current logging context while doing so.
+    Useful for wrapping functions that return a deferred which you don't yield
+    on (for instance because you want to pass it to deferred.gatherResults()).
 
-    If the result is a deferred, call preserve_context_over_deferred before
-    returning it.
+    Note that if you completely discard the result, you should make sure that
+    `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+    CRITICAL error about an unhandled error will be logged without much
+    indication about where it came from.
     """
-    with PreserveLoggingContext():
-        res = fn(*args, **kwargs)
+    current = LoggingContext.current_context()
+    try:
+        res = f(*args, **kwargs)
+    except:  # noqa: E722
+        # the assumption here is that the caller doesn't want to be disturbed
+        # by synchronous exceptions, so let's turn them into Failures.
+        return defer.fail()
 
-    if isinstance(res, defer.Deferred):
-        return preserve_context_over_deferred(res)
-    else:
+    if not isinstance(res, defer.Deferred):
         return res
 
-
-def preserve_context_over_deferred(deferred, context=None):
-    """Given a deferred wrap it such that any callbacks added later to it will
-    be invoked with the current context.
-
-    Deprecated: this almost certainly doesn't do want you want, ie make
-    the deferred follow the synapse logcontext rules: try
-    ``make_deferred_yieldable`` instead.
-    """
-    if context is None:
-        context = LoggingContext.current_context()
-    d = _PreservingContextDeferred(context)
-    deferred.chainDeferred(d)
-    return d
-
-
-def preserve_fn(f):
-    """Wraps a function, to ensure that the current context is restored after
-    return from the function, and that the sentinel context is set once the
-    deferred returned by the funtion completes.
-
-    Useful for wrapping functions that return a deferred which you don't yield
-    on.
-    """
-    def reset_context(result):
-        LoggingContext.set_current_context(LoggingContext.sentinel)
-        return result
-
-    def g(*args, **kwargs):
-        current = LoggingContext.current_context()
-        res = f(*args, **kwargs)
-        if isinstance(res, defer.Deferred) and not res.called:
-            # The function will have reset the context before returning, so
-            # we need to restore it now.
-            LoggingContext.set_current_context(current)
-
-            # The original context will be restored when the deferred
-            # completes, but there is nothing waiting for it, so it will
-            # get leaked into the reactor or some other function which
-            # wasn't expecting it. We therefore need to reset the context
-            # here.
-            #
-            # (If this feels asymmetric, consider it this way: we are
-            # effectively forking a new thread of execution. We are
-            # probably currently within a ``with LoggingContext()`` block,
-            # which is supposed to have a single entry and exit point. But
-            # by spawning off another deferred, we are effectively
-            # adding a new exit point.)
-            res.addBoth(reset_context)
+    if res.called and not res.paused:
+        # The function should have maintained the logcontext, so we can
+        # optimise out the messing about
         return res
-    return g
+
+    # The function may have reset the context before returning, so
+    # we need to restore it now.
+    ctx = LoggingContext.set_current_context(current)
+
+    # The original context will be restored when the deferred
+    # completes, but there is nothing waiting for it, so it will
+    # get leaked into the reactor or some other function which
+    # wasn't expecting it. We therefore need to reset the context
+    # here.
+    #
+    # (If this feels asymmetric, consider it this way: we are
+    # effectively forking a new thread of execution. We are
+    # probably currently within a ``with LoggingContext()`` block,
+    # which is supposed to have a single entry and exit point. But
+    # by spawning off another deferred, we are effectively
+    # adding a new exit point.)
    res.addBoth(_set_context_cb, ctx)
    return res
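A short usage sketch for run_in_background; background_work() and handle_request() are invented for illustration:

    from twisted.internet import defer, reactor, task

    from synapse.util.logcontext import LoggingContext, run_in_background

    @defer.inlineCallbacks
    def background_work():
        # something slow which returns a deferred
        yield task.deferLater(reactor, 0.1, lambda: None)

    def handle_request():
        with LoggingContext(name="request"):
            # fire-and-forget: our context is restored before returning, and
            # reset to the sentinel when the deferred eventually completes,
            # so it is not leaked into the reactor
            run_in_background(background_work)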
 
 
-@defer.inlineCallbacks
 def make_deferred_yieldable(deferred):
     """Given a deferred, make it follow the Synapse logcontext rules:
@@ -371,11 +498,27 @@ def make_deferred_yieldable(deferred):
     returning a deferred. Then, when the deferred completes, restores the
     current logcontext before running callbacks/errbacks.
 
-    (This is more-or-less the opposite operation to preserve_fn.)
+    (This is more-or-less the opposite operation to run_in_background.)
     """
-    with PreserveLoggingContext():
-        r = yield deferred
-    defer.returnValue(r)
+    if not isinstance(deferred, defer.Deferred):
+        return deferred
+
+    if deferred.called and not deferred.paused:
+        # it looks like this deferred is ready to run any callbacks we give it
+        # immediately. We may as well optimise out the logcontext faffery.
+        return deferred
+
+    # ok, we can't be sure that a yield won't block, so let's reset the
+    # logcontext, and add a callback to the deferred to restore it.
+    prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+    deferred.addBoth(_set_context_cb, prev_context)
+    return deferred
+
+
+def _set_context_cb(result, context):
+    """A callback function which just sets the logging context"""
+    LoggingContext.set_current_context(context)
+    return result
 
 
 # modules to ignore in `logcontext_tracer`
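Conversely, make_deferred_yieldable is for the consuming side. A sketch, in which get_raw_deferred stands in for any function returning a deferred that does not follow the logcontext rules:

    from twisted.internet import defer

    from synapse.util.logcontext import LoggingContext, make_deferred_yieldable

    @defer.inlineCallbacks
    def fetch_something(get_raw_deferred):
        with LoggingContext(name="request"):
            # our context is reset for the duration of the yield and
            # restored when the deferred completes
            result = yield make_deferred_yieldable(get_raw_deferred())
            defer.returnValue(result)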
""" - with PreserveLoggingContext(): - r = yield deferred - defer.returnValue(r) + if not isinstance(deferred, defer.Deferred): + return deferred + + if deferred.called and not deferred.paused: + # it looks like this deferred is ready to run any callbacks we give it + # immediately. We may as well optimise out the logcontext faffery. + return deferred + + # ok, we can't be sure that a yield won't block, so let's reset the + # logcontext, and add a callback to the deferred to restore it. + prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + deferred.addBoth(_set_context_cb, prev_context) + return deferred + + +def _set_context_cb(result, context): + """A callback function which just sets the logging context""" + LoggingContext.set_current_context(context) + return result # modules to ignore in `logcontext_tracer` diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py new file mode 100644 index 0000000000..a46bc47ce3 --- /dev/null +++ b/synapse/util/logformatter.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import traceback + +from six import StringIO + + +class LogFormatter(logging.Formatter): + """Log formatter which gives more detail for exceptions + + This is the same as the standard log formatter, except that when logging + exceptions [typically via log.foo("msg", exc_info=1)], it prints the + sequence that led up to the point at which the exception was caught. + (Normally only stack frames between the point the exception was raised and + where it was caught are logged). + """ + def __init__(self, *args, **kwargs): + super(LogFormatter, self).__init__(*args, **kwargs) + + def formatException(self, ei): + sio = StringIO() + (typ, val, tb) = ei + + # log the stack above the exception capture point if possible, but + # check that we actually have an f_back attribute to work around + # https://twistedmatrix.com/trac/ticket/9305 + + if tb and hasattr(tb.tb_frame, 'f_back'): + sio.write("Capture point (most recent call last):\n") + traceback.print_stack(tb.tb_frame.f_back, None, sio) + + traceback.print_exception(typ, val, tb, None, sio) + s = sio.getvalue() + sio.close() + if s[-1:] == "\n": + s = s[:-1] + return s diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index 3a83828d25..62a00189cc 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -14,13 +14,11 @@ # limitations under the License. 
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index 3a83828d25..62a00189cc 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -14,13 +14,11 @@
 # limitations under the License.
 
 
-from inspect import getcallargs
-from functools import wraps
-
-import logging
 import inspect
+import logging
 import time
-
+from functools import wraps
+from inspect import getcallargs
 
 _TIME_FUNC_ID = 0
 
@@ -96,7 +94,7 @@ def time_function(f):
         id = _TIME_FUNC_ID
         _TIME_FUNC_ID += 1
 
-        start = time.clock() * 1000
+        start = time.clock()
 
         try:
             _log_debug_as_f(
@@ -107,10 +105,10 @@ def time_function(f):
 
             r = f(*args, **kwargs)
         finally:
-            end = time.clock() * 1000
+            end = time.clock()
             _log_debug_as_f(
                 f,
-                "[FUNC END] {%s-%d} %f",
+                "[FUNC END] {%s-%d} %.3f sec",
                 (func_name, id, end - start,),
             )
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 97e0f00b67..14be3c7396 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.conch.manhole import ColoredManhole
-from twisted.conch.insults import insults
 from twisted.conch import manhole_ssh
-from twisted.cred import checkers, portal
+from twisted.conch.insults import insults
+from twisted.conch.manhole import ColoredManhole
 from twisted.conch.ssh.keys import Key
+from twisted.cred import checkers, portal
 
 PUBLIC_KEY = (
     "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az"
"created_context" + "clock", "name", "start_context", "start", + "created_context", + "start_usage", ] def __init__(self, clock, name): @@ -75,23 +73,23 @@ class Measure(object): self.created_context = False def __enter__(self): - self.start = self.clock.time_msec() + self.start = self.clock.time() self.start_context = LoggingContext.current_context() if not self.start_context: self.start_context = LoggingContext("Measure") self.start_context.__enter__() self.created_context = True - self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() - self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + self.start_usage = self.start_context.get_resource_usage() def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_type, Exception) or not self.start_context: return - duration = self.clock.time_msec() - self.start - block_timer.inc_by(duration, self.name) + duration = self.clock.time() - self.start + + block_counter.labels(self.name).inc() + block_timer.labels(self.name).inc(duration) context = LoggingContext.current_context() @@ -106,16 +104,19 @@ class Measure(object): logger.warn("Expected context. (%r)", self.name) return - ru_utime, ru_stime = context.get_resource_usage() - - block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name) - block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name) - block_db_txn_count.inc_by( - context.db_txn_count - self.db_txn_count, self.name - ) - block_db_txn_duration.inc_by( - context.db_txn_duration - self.db_txn_duration, self.name - ) + current = context.get_resource_usage() + usage = current - self.start_usage + try: + block_ru_utime.labels(self.name).inc(usage.ru_utime) + block_ru_stime.labels(self.name).inc(usage.ru_stime) + block_db_txn_count.labels(self.name).inc(usage.db_txn_count) + block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) + block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) + except ValueError: + logger.warn( + "Failed to save metrics! OLD: %r, NEW: %r", + self.start_usage, current + ) if self.created_context: self.start_context.__exit__(exc_type, exc_val, exc_tb) diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py new file mode 100644 index 0000000000..4288312b8a --- /dev/null +++ b/synapse/util/module_loader.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + +from synapse.config._base import ConfigError + + +def load_module(provider): + """ Loads a module with its config + Take a dict with keys 'module' (the module name) and 'config' + (the config dict). + + Returns + Tuple of (provider class, parsed config object) + """ + # We need to import the module, and then pick the class out of + # that, so we split based on the last dot. 
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
new file mode 100644
index 0000000000..4288312b8a
--- /dev/null
+++ b/synapse/util/module_loader.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+
+from synapse.config._base import ConfigError
+
+
+def load_module(provider):
+    """ Loads a module with its config
+    Takes a dict with keys 'module' (the module name) and 'config'
+    (the config dict).
+
+    Returns:
+        Tuple of (provider class, parsed config object)
+    """
+    # We need to import the module, and then pick the class out of
+    # that, so we split based on the last dot.
+    module, clz = provider['module'].rsplit(".", 1)
+    module = importlib.import_module(module)
+    provider_class = getattr(module, clz)
+
+    try:
+        provider_config = provider_class.parse_config(provider["config"])
+    except Exception as e:
+        raise ConfigError(
+            "Failed to parse config for %r: %r" % (provider['module'], e)
+        )
+
+    return provider_class, provider_config
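A sketch of what load_module() consumes and returns; the module path and config keys below are invented for illustration:

    from synapse.util.module_loader import load_module

    provider = {
        "module": "my_package.auth.MyAuthProvider",   # hypothetical class
        "config": {"api_key": "secret"},
    }

    provider_class, provider_config = load_module(provider)
    # provider_class is the MyAuthProvider class itself;
    # provider_config is whatever MyAuthProvider.parse_config() returned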
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
index 607161e7f0..a6c30e5265 100644
--- a/synapse/util/msisdn.py
+++ b/synapse/util/msisdn.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import phonenumbers
+
 from synapse.api.errors import SynapseError
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 1101881a2d..7deb38f2a7 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -13,17 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
-from synapse.api.errors import LimitExceededError
-
-from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_fn
-
 import collections
 import contextlib
 import logging
 
+from twisted.internet import defer
+
+from synapse.api.errors import LimitExceededError
+from synapse.util.logcontext import (
+    PreserveLoggingContext,
+    make_deferred_yieldable,
+    run_in_background,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -91,13 +92,22 @@ class _PerHostRatelimiter(object):
 
         self.window_size = window_size
         self.sleep_limit = sleep_limit
-        self.sleep_msec = sleep_msec
+        self.sleep_sec = sleep_msec / 1000.0
         self.reject_limit = reject_limit
         self.concurrent_requests = concurrent_requests
 
+        # request_id objects for requests which have been slept
        self.sleeping_requests = set()
+
+        # map from request_id object to Deferred for requests which are ready
+        # for processing but have been queued
        self.ready_request_queue = collections.OrderedDict()
+
+        # request id objects for requests which are in progress
        self.current_processing = set()
+
+        # times at which we have recently (within the last window_size ms)
+        # received requests.
        self.request_times = []
 
     @contextlib.contextmanager
@@ -116,11 +126,15 @@ class _PerHostRatelimiter(object):
 
     def _on_enter(self, request_id):
         time_now = self.clock.time_msec()
+
+        # remove any entries from request_times which aren't within the window
         self.request_times[:] = [
             r for r in self.request_times
             if time_now - r < self.window_size
         ]
 
+        # reject the request if we already have too many queued up (either
+        # sleeping or in the ready queue).
         queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
         if queue_size > self.reject_limit:
             raise LimitExceededError(
@@ -133,9 +147,13 @@ class _PerHostRatelimiter(object):
 
         def queue_request():
             if len(self.current_processing) > self.concurrent_requests:
-                logger.debug("Ratelimit [%s]: Queue req", id(request_id))
                 queue_defer = defer.Deferred()
                 self.ready_request_queue[request_id] = queue_defer
+                logger.info(
+                    "Ratelimiter: queueing request (queue now %i items)",
+                    len(self.ready_request_queue),
+                )
+
                 return queue_defer
             else:
                 return defer.succeed(None)
@@ -147,10 +165,9 @@ class _PerHostRatelimiter(object):
 
         if len(self.request_times) > self.sleep_limit:
             logger.debug(
-                "Ratelimit [%s]: sleeping req",
-                id(request_id),
+                "Ratelimiter: sleeping request for %f sec", self.sleep_sec,
             )
-            ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
+            ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
 
             self.sleeping_requests.add(request_id)
 
@@ -176,6 +193,9 @@ class _PerHostRatelimiter(object):
             return r
 
         def on_err(r):
+            # XXX: why is this necessary? this is called before we start
+            # processing the request so why would the request be in
+            # current_processing?
             self.current_processing.discard(request_id)
             return r
 
@@ -187,7 +207,7 @@ class _PerHostRatelimiter(object):
 
         ret_defer.addCallbacks(on_start, on_err)
         ret_defer.addBoth(on_both)
-        return ret_defer
+        return make_deferred_yieldable(ret_defer)
 
     def _on_exit(self, request_id):
         logger.debug(
@@ -196,8 +216,10 @@ class _PerHostRatelimiter(object):
         )
         self.current_processing.discard(request_id)
         try:
-            request_id, deferred = self.ready_request_queue.popitem()
-            self.current_processing.add(request_id)
-            deferred.callback(None)
+            # start processing the next item on the queue.
+            _, deferred = self.ready_request_queue.popitem(last=False)
+
+            with PreserveLoggingContext():
+                deferred.callback(None)
         except KeyError:
             pass
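A sketch of how the ratelimit() context manager is consumed; in Synapse this is normally reached via FederationRateLimiter, and the surrounding handler below is invented:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def handle_incoming_request(ratelimiter, request):
        with ratelimiter.ratelimit() as wait_deferred:
            # fires once we may proceed: immediately, after a sleep, or
            # after being queued behind other in-flight requests
            yield wait_deferred
            # ... process the request ...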
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 4fa9d1a03c..8a3a06fd74 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -12,20 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import synapse.util.logcontext
-from twisted.internet import defer
-
-from synapse.api.errors import CodeMessageException
-
 import logging
 import random
 
+from twisted.internet import defer
+
+import synapse.util.logcontext
+from synapse.api.errors import CodeMessageException
 
 logger = logging.getLogger(__name__)
 
 
 class NotRetryingDestination(Exception):
     def __init__(self, retry_last_ts, retry_interval, destination):
+        """Raised by the limiter (and federation client) to indicate that we are
+        deliberately not attempting to contact a given server.
+
+        Args:
+            retry_last_ts (int): the unix ts in milliseconds of our last attempt
+                to contact the server.  0 indicates that the last attempt was
+                successful or that we've never actually attempted to connect.
+            retry_interval (int): the time in milliseconds to wait until the next
+                attempt.
+            destination (str): the domain in question
+        """
+
+        msg = "Not retrying server %s." % (destination,)
         super(NotRetryingDestination, self).__init__(msg)
 
@@ -189,10 +200,10 @@ class RetryDestinationLimiter(object):
                 yield self.store.set_destination_retry_timings(
                     self.destination, retry_last_ts, self.retry_interval
                 )
-            except:
+            except Exception:
                 logger.exception(
-                    "Failed to store set_destination_retry_timings",
+                    "Failed to store destination_retry_timings",
                 )
 
         # we deliberately do this in the background.
-        synapse.util.logcontext.preserve_fn(store_retry_timings)()
+        synapse.util.logcontext.run_in_background(store_retry_timings)
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
index f4a9abf83f..6c0f2bb0cf 100644
--- a/synapse/util/rlimit.py
+++ b/synapse/util/rlimit.py
@@ -13,9 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import resource
 import logging
-
+import resource
 
 logger = logging.getLogger("synapse.app.homeserver")
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 95a6168e16..43d9db67ec 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -16,18 +16,20 @@
 import random
 import string
 
+from six.moves import range
+
 _string_with_symbols = (
     string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
 )
 
 
 def random_string(length):
-    return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+    return ''.join(random.choice(string.ascii_letters) for _ in range(length))
 
 
 def random_string_with_symbols(length):
     return ''.join(
-        random.choice(_string_with_symbols) for _ in xrange(length)
+        random.choice(_string_with_symbols) for _ in range(length)
     )
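One caveat worth noting on the stringutils hunk: random.choice draws from Python's default Mersenne Twister generator, which is statistically fine but not suitable for security-sensitive tokens. A sketch of a CSPRNG-backed variant (this is not what the code above does):

    import string
    from random import SystemRandom   # backed by os.urandom

    _rng = SystemRandom()

    def random_string_secure(length):
        return ''.join(_rng.choice(string.ascii_letters) for _ in range(length))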
"wotan@matrix.org") + msisdns need to first have been canonicalised + Returns: + bool: whether the 3PID medium/address is allowed to be added to this HS + """ + + if hs.config.allowed_local_3pids: + for constraint in hs.config.allowed_local_3pids: + logger.debug( + "Checking 3PID %s (%s) against %s (%s)", + address, medium, constraint['pattern'], constraint['medium'], + ) + if ( + medium == constraint['medium'] and + re.match(constraint['pattern'], address) + ): + return True + else: + return True + + return False diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 52086df465..1fbcd41115 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -14,9 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import os import logging +import os +import subprocess logger = logging.getLogger(__name__) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 7412fc57a4..7a9e45aca9 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six.moves import range + class _Entry(object): __slots__ = ["end_key", "queue"] @@ -68,7 +70,7 @@ class WheelTimer(object): # Add empty entries between the end of the current list and when we want # to insert. This ensures there are no gaps. self.entries.extend( - _Entry(key) for key in xrange(last_key, then_key + 1) + _Entry(key) for key in range(last_key, then_key + 1) ) self.entries[-1].queue.append(obj) @@ -91,7 +93,4 @@ class WheelTimer(object): return ret def __len__(self): - l = 0 - for entry in self.entries: - l += len(entry.queue) - return l + return sum(len(entry.queue) for entry in self.entries) |