summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py58
-rw-r--r--synapse/util/async_helpers.py115
-rw-r--r--synapse/util/caches/__init__.py28
-rw-r--r--synapse/util/caches/descriptors.py276
-rw-r--r--synapse/util/caches/dictionary_cache.py15
-rw-r--r--synapse/util/caches/expiringcache.py18
-rw-r--r--synapse/util/caches/lrucache.py14
-rw-r--r--synapse/util/caches/response_cache.py30
-rw-r--r--synapse/util/caches/snapshot_cache.py94
-rw-r--r--synapse/util/caches/stream_change_cache.py13
-rw-r--r--synapse/util/caches/treecache.py5
-rw-r--r--synapse/util/caches/ttlcache.py9
-rw-r--r--synapse/util/distributor.py31
-rw-r--r--synapse/util/file_consumer.py2
-rw-r--r--synapse/util/frozenutils.py11
-rw-r--r--synapse/util/hash.py33
-rw-r--r--synapse/util/httpresourcetree.py10
-rw-r--r--synapse/util/iterutils.py48
-rw-r--r--synapse/util/jsonobject.py6
-rw-r--r--synapse/util/logcontext.py653
-rw-r--r--synapse/util/logformatter.py43
-rw-r--r--synapse/util/logutils.py209
-rw-r--r--synapse/util/manhole.py18
-rw-r--r--synapse/util/metrics.py111
-rw-r--r--synapse/util/module_loader.py28
-rw-r--r--synapse/util/msisdn.py6
-rw-r--r--synapse/util/patch_inline_callbacks.py219
-rw-r--r--synapse/util/ratelimitutils.py51
-rw-r--r--synapse/util/retryutils.py68
-rw-r--r--synapse/util/rlimit.py2
-rw-r--r--synapse/util/stringutils.py23
-rw-r--r--synapse/util/threepids.py25
-rw-r--r--synapse/util/versionstring.py92
-rw-r--r--synapse/util/wheel_timer.py4
34 files changed, 950 insertions, 1418 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py

index 8f5a526800..60f0de70f7 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py
@@ -15,13 +15,12 @@ import logging import re -from itertools import islice import attr from twisted.internet import defer, task -from synapse.util.logcontext import PreserveLoggingContext +from synapse.logging import context logger = logging.getLogger(__name__) @@ -40,15 +39,16 @@ class Clock(object): Args: reactor: The Twisted reactor to use. """ + _reactor = attr.ib() @defer.inlineCallbacks def sleep(self, seconds): d = defer.Deferred() - with PreserveLoggingContext(): + with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) res = yield d - defer.returnValue(res) + return res def time(self): """Returns the current system time in seconds since epoch.""" @@ -61,7 +61,10 @@ class Clock(object): def looping_call(self, f, msec, *args, **kwargs): """Call a function repeatedly. - Waits `msec` initially before calling `f` for the first time. + Waits `msec` initially before calling `f` for the first time. + + Note that the function will be called with no logcontext, so if it is anything + other than trivial, you probably want to wrap it in run_as_background_process. Args: f(function): The function to call repeatedly. @@ -72,25 +75,27 @@ class Clock(object): call = task.LoopingCall(f, *args, **kwargs) call.clock = self._reactor d = call.start(msec / 1000.0, now=False) - d.addErrback( - log_failure, "Looping call died", consumeErrors=False, - ) + d.addErrback(log_failure, "Looping call died", consumeErrors=False) return call def call_later(self, delay, callback, *args, **kwargs): """Call something later + Note that the function will be called with no logcontext, so if it is anything + other than trivial, you probably want to wrap it in run_as_background_process. + Args: delay(float): How long to wait in seconds. callback(function): Function to call *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. """ + def wrapped_callback(*args, **kwargs): - with PreserveLoggingContext(): + with context.PreserveLoggingContext(): callback(*args, **kwargs) - with PreserveLoggingContext(): + with context.PreserveLoggingContext(): return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) def cancel_call_later(self, timer, ignore_errs=False): @@ -101,22 +106,6 @@ class Clock(object): raise -def batch_iter(iterable, size): - """batch an iterable up into tuples with a maximum size - - Args: - iterable (iterable): the iterable to slice - size (int): the maximum batch size - - Returns: - an iterator over the chunks - """ - # make sure we can deal with iterables like lists too - sourceiter = iter(iterable) - # call islice until it returns an empty tuple - return iter(lambda: tuple(islice(sourceiter, size)), ()) - - def log_failure(failure, msg, consumeErrors=True): """Creates a function suitable for passing to `Deferred.addErrback` that logs any failures that occur. @@ -131,12 +120,7 @@ def log_failure(failure, msg, consumeErrors=True): """ logger.error( - msg, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject() - ) + msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) ) if not consumeErrors: @@ -154,12 +138,12 @@ def glob_to_regex(glob): Returns: re.RegexObject """ - res = '' + res = "" for c in glob: - if c == '*': - res = res + '.*' - elif c == '?': - res = res + '.' + if c == "*": + res = res + ".*" + elif c == "?": + res = res + "." else: res = res + re.escape(c) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 7253ba120f..581dffd8a0 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -13,23 +13,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import collections import logging from contextlib import contextmanager +from typing import Dict, Sequence, Set, Union from six.moves import range +import attr + from twisted.internet import defer from twisted.internet.defer import CancelledError from twisted.python import failure -from synapse.util import Clock, logcontext, unwrapFirstError - -from .logcontext import ( +from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background, ) +from synapse.util import Clock, unwrapFirstError logger = logging.getLogger(__name__) @@ -70,6 +73,10 @@ class ObservableDeferred(object): def errback(f): object.__setattr__(self, "_result", (False, f)) while self._observers: + # This is a little bit of magic to correctly propagate stack + # traces when we `await` on one of the observer deferreds. + f.value.__failure__ = f + try: # TODO: Handle errors here. self._observers.pop().errback(f) @@ -83,11 +90,12 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) - def observe(self): + def observe(self) -> defer.Deferred: """Observe the underlying deferred. - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. + This returns a brand new deferred that is resolved when the underlying + deferred is resolved. Interacting with the returned deferred does not + effect the underdlying deferred. """ if not self._result: d = defer.Deferred() @@ -95,13 +103,14 @@ class ObservableDeferred(object): def remove(r): self._observers.discard(d) return r + d.addBoth(remove) self._observers.add(d) return d else: success, res = self._result - return res if success else defer.fail(res) + return defer.succeed(res) if success else defer.fail(res) def observers(self): return self._observers @@ -123,7 +132,9 @@ class ObservableDeferred(object): def __repr__(self): return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % ( - id(self), self._result, self._deferred, + id(self), + self._result, + self._deferred, ) @@ -132,9 +143,9 @@ def concurrently_execute(func, args, limit): the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred. - args (list): List of arguments to pass to func, each invocation of func - gets a signle argument. + func (func): Function to execute, should return a deferred or coroutine. + args (Iterable): List of arguments to pass to func, each invocation of func + gets a single argument. limit (int): Maximum number of conccurent executions. Returns: @@ -142,18 +153,19 @@ def concurrently_execute(func, args, limit): """ it = iter(args) - @defer.inlineCallbacks - def _concurrently_execute_inner(): + async def _concurrently_execute_inner(): try: while True: - yield func(next(it)) + await maybe_awaitable(func(next(it))) except StopIteration: pass - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(_concurrently_execute_inner) - for _ in range(limit) - ], consumeErrors=True)).addErrback(unwrapFirstError) + return make_deferred_yieldable( + defer.gatherResults( + [run_in_background(_concurrently_execute_inner) for _ in range(limit)], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) def yieldable_gather_results(func, iter, *args, **kwargs): @@ -169,10 +181,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs): Deferred[list]: Resolved when all functions have been invoked, or errors if one of the function calls fails. """ - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(func, item, *args, **kwargs) - for item in iter - ], consumeErrors=True)).addErrback(unwrapFirstError) + return make_deferred_yieldable( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) class Linearizer(object): @@ -185,6 +199,7 @@ class Linearizer(object): # do some work. """ + def __init__(self, name=None, max_count=1, clock=None): """ Args: @@ -197,6 +212,7 @@ class Linearizer(object): if not clock: from twisted.internet import reactor + clock = Clock(reactor) self._clock = clock self.max_count = max_count @@ -205,7 +221,9 @@ class Linearizer(object): # the first element is the number of things executing, and # the second element is an OrderedDict, where the keys are deferreds for the # things blocked from executing. - self.key_to_defer = {} + self.key_to_defer = ( + {} + ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]] def queue(self, key): # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. @@ -221,7 +239,7 @@ class Linearizer(object): res = self._await_lock(key) else: logger.debug( - "Acquired uncontended linearizer lock %r for key %r", self.name, key, + "Acquired uncontended linearizer lock %r for key %r", self.name, key ) entry[0] += 1 res = defer.succeed(None) @@ -266,9 +284,7 @@ class Linearizer(object): """ entry = self.key_to_defer[key] - logger.debug( - "Waiting to acquire linearizer lock %r for key %r", self.name, key, - ) + logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer = make_deferred_yieldable(defer.Deferred()) entry[1][new_defer] = 1 @@ -293,14 +309,14 @@ class Linearizer(object): logger.info("defer %r got err %r", new_defer, e) if isinstance(e, CancelledError): logger.debug( - "Cancelling wait for linearizer lock %r for key %r", - self.name, key, + "Cancelling wait for linearizer lock %r for key %r", self.name, key ) else: - logger.warn( + logger.warning( "Unexpected exception waiting for linearizer lock %r for key %r", - self.name, key, + self.name, + key, ) # we just have to take ourselves back out of the queue. @@ -334,10 +350,10 @@ class ReadWriteLock(object): def __init__(self): # Latest readers queued - self.key_to_current_readers = {} + self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]] # Latest writer queued - self.key_to_current_writer = {} + self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] @defer.inlineCallbacks def read(self, key): @@ -360,7 +376,7 @@ class ReadWriteLock(object): new_defer.callback(None) self.key_to_current_readers.get(key, set()).discard(new_defer) - defer.returnValue(_ctx_manager()) + return _ctx_manager() @defer.inlineCallbacks def write(self, key): @@ -390,7 +406,7 @@ class ReadWriteLock(object): if self.key_to_current_writer[key] == new_defer: self.key_to_current_writer.pop(key) - defer.returnValue(_ctx_manager()) + return _ctx_manager() def _cancelled_to_timed_out_error(value, timeout): @@ -438,7 +454,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): try: deferred.cancel() - except: # noqa: E722, if we throw any exception it'll break time outs + except: # noqa: E722, if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") if not new_d.called: @@ -473,3 +489,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): deferred.addCallbacks(success_cb, failure_cb) return new_d + + +@attr.s(slots=True, frozen=True) +class DoneAwaitable(object): + """Simple awaitable that returns the provided value. + """ + + value = attr.ib() + + def __await__(self): + return self + + def __iter__(self): + return self + + def __next__(self): + raise StopIteration(self.value) + + +def maybe_awaitable(value): + """Convert a value to an awaitable if not already an awaitable. + """ + + if hasattr(value, "__await__"): + return value + + return DoneAwaitable(value) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index f37d5bec08..da5077b471 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +16,7 @@ import logging import os +from typing import Dict import six from six.moves import intern @@ -36,7 +38,7 @@ def get_cache_factor_for(cache_name): caches_by_name = {} -collectors_by_name = {} +collectors_by_name = {} # type: Dict cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) @@ -51,7 +53,19 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -def register_cache(cache_type, cache_name, cache): +def register_cache(cache_type, cache_name, cache, collect_callback=None): + """Register a cache object for metric collection. + + Args: + cache_type (str): + cache_name (str): name of the cache + cache (object): cache itself + collect_callback (callable|None): if not None, a function which is called during + metric collection to update additional metrics. + + Returns: + CacheMetric: an object which provides inc_{hits,misses,evictions} methods + """ # Check if the metric is already registered. Unregister it, if so. # This usually happens during tests, as at runtime these caches are @@ -90,8 +104,10 @@ def register_cache(cache_type, cache_name, cache): cache_hits.labels(cache_name).set(self.hits) cache_evicted.labels(cache_name).set(self.evicted_size) cache_total.labels(cache_name).set(self.hits + self.misses) + if collect_callback: + collect_callback() except Exception as e: - logger.warn("Error calculating metrics for %s: %s", cache_name, e) + logger.warning("Error calculating metrics for %s: %s", cache_name, e) raise yield GaugeMetricFamily("__unused", "") @@ -104,8 +120,8 @@ def register_cache(cache_type, cache_name, cache): KNOWN_KEYS = { - key: key for key in - ( + key: key + for key in ( "auth_events", "content", "depth", @@ -150,7 +166,7 @@ def intern_dict(dictionary): def _intern_known_values(key, value): - intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",) + intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key") if key in intern_keys: return intern_string(value) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 187510576a..2e8f6543e5 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
@@ -17,32 +17,53 @@ import functools import inspect import logging import threading -from collections import namedtuple +from typing import Any, Tuple, Union, cast +from weakref import WeakValueDictionary -import six -from six import itervalues, string_types +from six import itervalues + +from prometheus_client import Gauge +from typing_extensions import Protocol from twisted.internet import defer -from synapse.util import logcontext, unwrapFirstError +from synapse.logging.context import make_deferred_yieldable, preserve_fn +from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry -from synapse.util.stringutils import to_ascii from . import register_cache logger = logging.getLogger(__name__) +CacheKey = Union[Tuple, Any] + + +class _CachedFunction(Protocol): + invalidate = None # type: Any + invalidate_all = None # type: Any + invalidate_many = None # type: Any + prefill = None # type: Any + cache = None # type: Any + num_args = None # type: Any + + def __name__(self): + ... + + +cache_pending_metric = Gauge( + "synapse_util_caches_cache_pending", + "Number of lookups currently pending for this cache", + ["name"], +) _CacheSentinel = object() class CacheEntry(object): - __slots__ = [ - "deferred", "callbacks", "invalidated" - ] + __slots__ = ["deferred", "callbacks", "invalidated"] def __init__(self, deferred, callbacks): self.deferred = deferred @@ -73,7 +94,9 @@ class Cache(object): self._pending_deferred_cache = cache_type() self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type, + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, ) @@ -81,11 +104,19 @@ class Cache(object): self.name = name self.keylen = keylen self.thread = None - self.metrics = register_cache("cache", name, self.cache) + self.metrics = register_cache( + "cache", + name, + self.cache, + collect_callback=self._metrics_collection_callback, + ) def _on_evicted(self, evicted_count): self.metrics.inc_evictions(evicted_count) + def _metrics_collection_callback(self): + cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) + def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -107,7 +138,7 @@ class Cache(object): update_metrics (bool): whether to update the cache hit rate metrics Returns: - Either a Deferred or the raw result + Either an ObservableDeferred or the raw result """ callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _CacheSentinel) @@ -131,12 +162,14 @@ class Cache(object): return default def set(self, key, value, callback=None): + if not isinstance(value, defer.Deferred): + raise TypeError("not a Deferred") + callbacks = [callback] if callback else [] self.check_thread() - entry = CacheEntry( - deferred=value, - callbacks=callbacks, - ) + observable = ObservableDeferred(value, consumeErrors=True) + observer = defer.maybeDeferred(observable.observe) + entry = CacheEntry(deferred=observable, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: @@ -144,20 +177,31 @@ class Cache(object): self._pending_deferred_cache[key] = entry - def shuffle(result): + def compare_and_pop(): + """Check if our entry is still the one in _pending_deferred_cache, and + if so, pop it. + + Returns true if the entries matched. + """ existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry is entry: + return True + + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + return False + + def cb(result): + if compare_and_pop(): self.cache.set(key, result, entry.callbacks) else: - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - # we're not going to put this entry into the cache, so need # to make sure that the invalidation callbacks are called. # That was probably done when _pending_deferred_cache was @@ -165,9 +209,16 @@ class Cache(object): # `invalidate` being previously called, in which case it may # not have been. Either way, let's double-check now. entry.invalidate() - return result - entry.deferred.addCallback(shuffle) + def eb(_fail): + compare_and_pop() + entry.invalidate() + + # once the deferred completes, we can move the entry from the + # _pending_deferred_cache to the real cache. + # + observer.addCallbacks(cb, eb) + return observable def prefill(self, key, value, callback=None): callbacks = [callback] if callback else [] @@ -191,9 +242,7 @@ class Cache(object): def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): - raise TypeError( - "The cache key must be a tuple not %r" % (type(key),) - ) + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the @@ -212,7 +261,9 @@ class Cache(object): class _CacheDescriptorBase(object): - def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): + def __init__( + self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False + ): self.orig = orig if inlineCallbacks: @@ -220,7 +271,7 @@ class _CacheDescriptorBase(object): else: self.function_to_call = orig - arg_spec = inspect.getargspec(orig) + arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args if "cache_context" in all_args: @@ -244,29 +295,25 @@ class _CacheDescriptorBase(object): raise Exception( "Not enough explicit positional arguments to key off for %r: " "got %i args, but wanted %i. (@cached cannot key off *args or " - "**kwargs)" - % (orig.__name__, len(all_args), num_args) + "**kwargs)" % (orig.__name__, len(all_args), num_args) ) self.num_args = num_args # list of the names of the args used as the cache key - self.arg_names = all_args[1:num_args + 1] + self.arg_names = all_args[1 : num_args + 1] # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: - self.arg_defaults = dict(zip( - all_args[-len(arg_spec.defaults):], - arg_spec.defaults - )) + self.arg_defaults = dict( + zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults) + ) else: self.arg_defaults = {} if "cache_context" in self.arg_names: - raise Exception( - "cache_context arg cannot be included among the cache keys" - ) + raise Exception("cache_context arg cannot be included among the cache keys") self.add_cache_context = cache_context @@ -297,19 +344,31 @@ class CacheDescriptor(_CacheDescriptorBase): def foo(self, key, cache_context): r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) - defer.returnValue(r1 + r2) + return r1 + r2 Args: num_args (int): number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. """ - def __init__(self, orig, max_entries=1000, num_args=None, tree=False, - inlineCallbacks=False, cache_context=False, iterable=False): + + def __init__( + self, + orig, + max_entries=1000, + num_args=None, + tree=False, + inlineCallbacks=False, + cache_context=False, + iterable=False, + ): super(CacheDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks, - cache_context=cache_context) + orig, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + cache_context=cache_context, + ) max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) @@ -356,12 +415,14 @@ class CacheDescriptor(_CacheDescriptorBase): return args[0] else: return self.arg_defaults[nm] + else: + def get_cache_key(args, kwargs): return tuple(get_cache_key_gen(args, kwargs)) @functools.wraps(self.orig) - def wrapped(*args, **kwargs): + def _wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -371,7 +432,7 @@ class CacheDescriptor(_CacheDescriptorBase): # Add our own `cache_context` to argument list if the wrapped function # has asked for one if self.add_cache_context: - kwargs["cache_context"] = _CacheContext(cache, cache_key) + kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) try: cached_result_d = cache.get(cache_key, callback=invalidate_callback) @@ -379,12 +440,11 @@ class CacheDescriptor(_CacheDescriptorBase): if isinstance(cached_result_d, ObservableDeferred): observer = cached_result_d.observe() else: - observer = cached_result_d + observer = defer.succeed(cached_result_d) except KeyError: ret = defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - obj, *args, **kwargs + preserve_fn(self.function_to_call), obj, *args, **kwargs ) def onErr(f): @@ -393,20 +453,12 @@ class CacheDescriptor(_CacheDescriptorBase): ret.addErrback(onErr) - # If our cache_key is a string on py2, try to convert to ascii - # to save a bit of space in large caches. Py3 does this - # internally automatically. - if six.PY2 and isinstance(cache_key, string_types): - cache_key = to_ascii(cache_key) - - result_d = ObservableDeferred(ret, consumeErrors=True) - cache.set(cache_key, result_d, callback=invalidate_callback) + result_d = cache.set(cache_key, ret, callback=invalidate_callback) observer = result_d.observe() - if isinstance(observer, defer.Deferred): - return logcontext.make_deferred_yieldable(observer) - else: - return observer + return make_deferred_yieldable(observer) + + wrapped = cast(_CachedFunction, _wrapped) if self.num_args == 1: wrapped.invalidate = lambda key: cache.invalidate(key[0]) @@ -432,13 +484,13 @@ class CacheListDescriptor(_CacheDescriptorBase): Given a list of keys it looks in the cache to find any hits, then passes the list of missing keys to the wrapped function. - Once wrapped, the function returns either a Deferred which resolves to - the list of results, or (if all results were cached), just the list of - results. + Once wrapped, the function returns a Deferred which resolves to the list + of results. """ - def __init__(self, orig, cached_method_name, list_name, num_args=None, - inlineCallbacks=False): + def __init__( + self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False + ): """ Args: orig (function) @@ -451,7 +503,8 @@ class CacheListDescriptor(_CacheDescriptorBase): be wrapped by defer.inlineCallbacks """ super(CacheListDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks) + orig, num_args=num_args, inlineCallbacks=inlineCallbacks + ) self.list_name = list_name @@ -463,7 +516,7 @@ class CacheListDescriptor(_CacheDescriptorBase): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cached_method_name,) + % (self.list_name, cached_method_name) ) def __get__(self, obj, objtype=None): @@ -494,8 +547,10 @@ class CacheListDescriptor(_CacheDescriptorBase): # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: + def arg_to_cache_key(arg): return arg + else: keylist = list(keyargs) @@ -505,8 +560,7 @@ class CacheListDescriptor(_CacheDescriptorBase): for arg in list_args: try: - res = cache.get(arg_to_cache_key(arg), - callback=invalidate_callback) + res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): @@ -519,7 +573,7 @@ class CacheListDescriptor(_CacheDescriptorBase): missing.add(arg) if missing: - # we need an observable deferred for each entry in the list, + # we need a deferred for each entry in the list, # which we put in the cache. Each deferred resolves with the # relevant result for that key. deferreds_map = {} @@ -527,8 +581,7 @@ class CacheListDescriptor(_CacheDescriptorBase): deferred = defer.Deferred() deferreds_map[arg] = deferred key = arg_to_cache_key(arg) - observable = ObservableDeferred(deferred) - cache.set(key, observable, callback=invalidate_callback) + cache.set(key, deferred, callback=invalidate_callback) def complete_all(res): # the wrapped function has completed. It returns a @@ -554,40 +607,62 @@ class CacheListDescriptor(_CacheDescriptorBase): args_to_call = dict(arg_dict) args_to_call[self.list_name] = list(missing) - cached_defers.append(defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - **args_to_call - ).addCallbacks(complete_all, errback)) + cached_defers.append( + defer.maybeDeferred( + preserve_fn(self.function_to_call), **args_to_call + ).addCallbacks(complete_all, errback) + ) if cached_defers: - d = defer.gatherResults( - cached_defers, - consumeErrors=True, - ).addCallbacks( - lambda _: results, - unwrapFirstError + d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( + lambda _: results, unwrapFirstError ) - return logcontext.make_deferred_yieldable(d) + return make_deferred_yieldable(d) else: - return results + return defer.succeed(results) obj.__dict__[self.orig.__name__] = wrapped return wrapped -class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): - # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, - # which namedtuple does for us (i.e. two _CacheContext are the same if - # their caches and keys match). This is important in particular to - # dedupe when we add callbacks to lru cache nodes, otherwise the number - # of callbacks would grow. - def invalidate(self): - self.cache.invalidate(self.key) +class _CacheContext: + """Holds cache information from the cached function higher in the calling order. + + Can be used to invalidate the higher level cache entry if something changes + on a lower level. + """ + + _cache_context_objects = ( + WeakValueDictionary() + ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] + + def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None + self._cache = cache + self._cache_key = cache_key + + def invalidate(self): # type: () -> None + """Invalidates the cache entry referred to by the context.""" + self._cache.invalidate(self._cache_key) + + @classmethod + def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext + """Returns an instance constructed with the given arguments. + + A new instance is only created if none already exists. + """ + + # We make sure there are no identical _CacheContext instances. This is + # important in particular to dedupe when we add callbacks to lru cache + # nodes, otherwise the number of callbacks would grow. + return cls._cache_context_objects.setdefault( + (cache, cache_key), cls(cache, cache_key) + ) -def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, - iterable=False): +def cached( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -598,8 +673,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, ) -def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, - cache_context=False, iterable=False): +def cachedInlineCallbacks( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 6c0b5a4094..6834e6f3ae 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py
@@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va there. value (dict): The full or partial dict value """ + def __len__(self): return len(self.value) @@ -84,13 +85,15 @@ class DictionaryCache(object): self.metrics.inc_hits() if dict_keys is None: - return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value)) + return DictionaryEntry( + entry.full, entry.known_absent, dict(entry.value) + ) else: - return DictionaryEntry(entry.full, entry.known_absent, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) + return DictionaryEntry( + entry.full, + entry.known_absent, + {k: entry.value[k] for k in dict_keys if k in entry.value}, + ) self.metrics.inc_misses() return DictionaryEntry(False, set(), {}) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index f369780277..cddf1ed515 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py
@@ -28,8 +28,15 @@ SENTINEL = object() class ExpiringCache(object): - def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False, iterable=False): + def __init__( + self, + cache_name, + clock, + max_len=0, + expiry_ms=0, + reset_expiry_on_get=False, + iterable=False, + ): """ Args: cache_name (str): Name of this cache, used for logging. @@ -67,8 +74,7 @@ class ExpiringCache(object): def f(): return run_as_background_process( - "prune_cache_%s" % self._cache_name, - self._prune_cache, + "prune_cache_%s" % self._cache_name, self._prune_cache ) self._clock.looping_call(f, self._expiry_ms / 2) @@ -153,7 +159,9 @@ class ExpiringCache(object): logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self) + self._cache_name, + begin_length, + len(self), ) def __len__(self): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index b684f24e7b..1536cb64f3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py
@@ -49,8 +49,15 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None, - evicted_callback=None): + + def __init__( + self, + max_size, + keylen=1, + cache_type=dict, + size_callback=None, + evicted_callback=None, + ): """ Args: max_size (int): @@ -93,9 +100,12 @@ class LruCache(object): cached_cache_len = [0] if size_callback is not None: + def cache_len(): return cached_cache_len[0] + else: + def cache_len(): return len(cache) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index afb03b2e1b..b68f9fe0d4 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py
@@ -16,9 +16,9 @@ import logging from twisted.internet import defer +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) @@ -35,12 +35,10 @@ class ResponseCache(object): self.pending_result_cache = {} # Requests that haven't finished yet. self.clock = hs.get_clock() - self.timeout_sec = timeout_ms / 1000. + self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache( - "response_cache", name, self - ) + self._metrics = register_cache("response_cache", name, self) def size(self): return len(self.pending_result_cache) @@ -80,7 +78,7 @@ class ResponseCache(object): *deferred* should run its callbacks in the sentinel logcontext (ie, you should wrap normal synapse deferreds with - logcontext.run_in_background). + synapse.logging.context.run_in_background). Can return either a new Deferred (which also doesn't follow the synapse logcontext rules), or, if *deferred* was already complete, the actual @@ -100,8 +98,7 @@ class ResponseCache(object): def remove(r): if self.timeout_sec: self.clock.call_later( - self.timeout_sec, - self.pending_result_cache.pop, key, None, + self.timeout_sec, self.pending_result_cache.pop, key, None ) else: self.pending_result_cache.pop(key, None) @@ -124,7 +121,7 @@ class ResponseCache(object): @defer.inlineCallbacks def handle_request(request): # etc - defer.returnValue(result) + return result result = yield response_cache.wrap( key, @@ -140,21 +137,22 @@ class ResponseCache(object): *args: positional parameters to pass to the callback, if it is used - **kwargs: named paramters to pass to the callback, if it is used + **kwargs: named parameters to pass to the callback, if it is used Returns: twisted.internet.defer.Deferred: yieldable result """ result = self.get(key) if not result: - logger.info("[%s]: no cached result for [%s], calculating new one", - self._name, key) + logger.debug( + "[%s]: no cached result for [%s], calculating new one", self._name, key + ) d = run_in_background(callback, *args, **kwargs) result = self.set(key, d) elif not isinstance(result, defer.Deferred) or result.called: - logger.info("[%s]: using completed cached result for [%s]", - self._name, key) + logger.info("[%s]: using completed cached result for [%s]", self._name, key) else: - logger.info("[%s]: using incomplete cached result for [%s]", - self._name, key) + logger.info( + "[%s]: using incomplete cached result for [%s]", self._name, key + ) return make_deferred_yieldable(result) diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py deleted file mode 100644
index 8318db8d2c..0000000000 --- a/synapse/util/caches/snapshot_cache.py +++ /dev/null
@@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.async_helpers import ObservableDeferred - - -class SnapshotCache(object): - """Cache for snapshots like the response of /initialSync. - The response of initialSync only has to be a recent snapshot of the - server state. It shouldn't matter to clients if it is a few minutes out - of date. - - This caches a deferred response. Until the deferred completes it will be - returned from the cache. This means that if the client retries the request - while the response is still being computed, that original response will be - used rather than trying to compute a new response. - - Once the deferred completes it will removed from the cache after 5 minutes. - We delay removing it from the cache because a client retrying its request - could race with us finishing computing the response. - - Rather than tracking precisely how long something has been in the cache we - keep two generations of completed responses. Every 5 minutes discard the - old generation, move the new generation to the old generation, and set the - new generation to be empty. This means that a result will be in the cache - somewhere between 5 and 10 minutes. - """ - - DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes. - - def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. - self.prev_result_cache = {} # The older requests that have finished. - self.next_result_cache = {} # The newer requests that have finished. - self.time_last_rotated_ms = 0 - - def rotate(self, time_now_ms): - # Rotate once if the cache duration has passed since the last rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms += self.DURATION_MS - - # Rotate again if the cache duration has passed twice since the last - # rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms = time_now_ms - - def get(self, time_now_ms, key): - self.rotate(time_now_ms) - # This cache is intended to deduplicate requests, so we expect it to be - # missed most of the time. So we just lookup the key in all of the - # dictionaries rather than trying to short circuit the lookup if the - # key is found. - result = self.prev_result_cache.get(key) - result = self.next_result_cache.get(key, result) - result = self.pending_result_cache.get(key, result) - if result is not None: - return result.observe() - else: - return None - - def set(self, time_now_ms, key, deferred): - self.rotate(time_now_ms) - - result = ObservableDeferred(deferred) - - self.pending_result_cache[key] = result - - def shuffle_along(r): - # When the deferred completes we shuffle it along to the first - # generation of the result cache. So that it will eventually - # expire from the rotation of that cache. - self.next_result_cache[key] = result - self.pending_result_cache.pop(key, None) - return r - - result.addBoth(shuffle_along) - - return result.observe() diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 625aedc940..235f64049c 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py
@@ -77,9 +77,8 @@ class StreamChangeCache(object): if stream_pos >= self._earliest_known_stream_pos: changed_entities = { - self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos), - ) + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) } result = changed_entities.intersection(entities) @@ -114,8 +113,10 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos >= self._earliest_known_stream_pos: - return [self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos))] + return [ + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) + ] else: return None @@ -136,7 +137,7 @@ class StreamChangeCache(object): while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) self._earliest_known_stream_pos = max( - k, self._earliest_known_stream_pos, + k, self._earliest_known_stream_pos ) self._entity_to_key.pop(r, None) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index dd4c9e6067..2ea4e4e911 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py
@@ -1,3 +1,5 @@ +from typing import Dict + from six import itervalues SENTINEL = object() @@ -9,9 +11,10 @@ class TreeCache(object): efficiently. Keys must be tuples. """ + def __init__(self): self.size = 0 - self.root = {} + self.root = {} # type: Dict def __setitem__(self, key, value): return self.set(key, value) diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 5ba1862506..99646c7cf0 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py
@@ -55,7 +55,7 @@ class TTLCache(object): if e != SENTINEL: self._expiry_list.remove(e) - entry = _CacheEntry(expiry_time=expiry, key=key, value=value) + entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value) self._data[key] = entry self._expiry_list.add(entry) @@ -87,7 +87,8 @@ class TTLCache(object): key: key to look up Returns: - Tuple[Any, float]: the value from the cache, and the expiry time + Tuple[Any, float, float]: the value from the cache, the expiry time + and the TTL Raises: KeyError if the entry is not found @@ -99,7 +100,7 @@ class TTLCache(object): self._metrics.inc_misses() raise self._metrics.inc_hits() - return e.value, e.expiry_time + return e.value, e.expiry_time, e.ttl def pop(self, key, default=SENTINEL): """Remove a value from the cache @@ -155,7 +156,9 @@ class TTLCache(object): @attr.s(frozen=True, slots=True) class _CacheEntry(object): """TTLCache entry""" + # expiry_time is the first attribute, so that entries are sorted by expiry. expiry_time = attr.ib() + ttl = attr.ib() key = attr.ib() value = attr.ib() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index e14c8bdfda..45af8d3eeb 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py
@@ -17,8 +17,8 @@ import logging from twisted.internet import defer +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) @@ -51,9 +51,7 @@ class Distributor(object): if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) - self.signals[name] = Signal( - name, - ) + self.signals[name] = Signal(name) if name in self.pre_registration: signal = self.signals[name] @@ -78,11 +76,7 @@ class Distributor(object): if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - run_as_background_process( - name, - self.signals[name].fire, - *args, **kwargs - ) + run_as_background_process(name, self.signals[name].fire, *args, **kwargs) class Signal(object): @@ -118,22 +112,23 @@ class Signal(object): def eb(failure): logger.warning( "%s signal observer %s failed: %r", - self.name, observer, failure, + self.name, + observer, + failure, exc_info=( failure.type, failure.value, - failure.getTracebackObject())) + failure.getTracebackObject(), + ), + ) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) - deferreds = [ - run_in_background(do, o) - for o in self.observers - ] + deferreds = [run_in_background(do, o) for o in self.observers] - return make_deferred_yieldable(defer.gatherResults( - deferreds, consumeErrors=True, - )) + return make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) def __repr__(self): return "<Signal name=%r>" % (self.name,) diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 629ed44149..8b17d1c8b8 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py
@@ -17,7 +17,7 @@ from six.moves import queue from twisted.internet import threads -from synapse.util.logcontext import make_deferred_yieldable, run_in_background +from synapse.logging.context import make_deferred_yieldable, run_in_background class BackgroundFileConsumer(object): diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 014edea971..f2ccd5e7c6 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py
@@ -30,7 +30,7 @@ def freeze(o): return o try: - return tuple([freeze(i) for i in o]) + return tuple(freeze(i) for i in o) except TypeError: pass @@ -60,11 +60,10 @@ def _handle_frozendict(obj): # fishing the protected dict out of the object is a bit nasty, # but we don't really want the overhead of copying the dict. return obj._dict - raise TypeError('Object of type %s is not JSON serializable' % - obj.__class__.__name__) + raise TypeError( + "Object of type %s is not JSON serializable" % obj.__class__.__name__ + ) # A JSONEncoder which is capable of encoding frozendics without barfing -frozendict_json_encoder = json.JSONEncoder( - default=_handle_frozendict, -) +frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) diff --git a/synapse/util/hash.py b/synapse/util/hash.py new file mode 100644
index 0000000000..359168704e --- /dev/null +++ b/synapse/util/hash.py
@@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib + +import unpaddedbase64 + + +def sha256_and_url_safe_base64(input_text): + """SHA256 hash an input string, encode the digest as url-safe base64, and + return + + :param input_text: string to hash + :type input_text: str + + :returns a sha256 hashed and url-safe base64 encoded digest + :rtype: str + """ + digest = hashlib.sha256(input_text.encode()).digest() + return unpaddedbase64.encode_base64(digest, urlsafe=True) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 2d7ddc1cbe..3c0e8469f3 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) def create_resource_tree(desired_tree, root_resource): - """Create the resource tree for this Home Server. + """Create the resource tree for this homeserver. This in unduly complicated because Twisted does not support putting child resources more than 1 level deep at a time. @@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource): logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource - for path_seg in full_path.split(b'/')[1:-1]: + for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" child_resource = NoResource() @@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource): # =========================== # now attach the actual desired resource - last_path_seg = full_path.split(b'/')[-1] + last_path_seg = full_path.split(b"/")[-1] # if there is already a resource here, thieve its children and # replace it @@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource): # to be replaced with the desired resource. existing_dummy_resource = resource_mappings[res_id] for child_name in existing_dummy_resource.listNames(): - child_res_id = _resource_id( - existing_dummy_resource, child_name - ) + child_res_id = _resource_id(existing_dummy_resource, child_name) child_resource = resource_mappings[child_res_id] # steal the children res.putChild(child_name, child_resource) diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py new file mode 100644
index 0000000000..06faeebe7f --- /dev/null +++ b/synapse/util/iterutils.py
@@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from itertools import islice +from typing import Iterable, Iterator, Sequence, Tuple, TypeVar + +T = TypeVar("T") + + +def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: + """batch an iterable up into tuples with a maximum size + + Args: + iterable (iterable): the iterable to slice + size (int): the maximum batch size + + Returns: + an iterator over the chunks + """ + # make sure we can deal with iterables like lists too + sourceiter = iter(iterable) + # call islice until it returns an empty tuple + return iter(lambda: tuple(islice(sourceiter, size)), ()) + + +ISeq = TypeVar("ISeq", bound=Sequence, covariant=True) + + +def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: + """Split the given sequence into chunks of the given size + + The last chunk may be shorter than the given size. + + If the input is empty, no chunks are returned. + """ + return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index d668e5a6b8..6dce03dd3a 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py
@@ -70,7 +70,8 @@ class JsonEncodedObject(object): dict """ d = { - k: _encode(v) for (k, v) in self.__dict__.items() + k: _encode(v) + for (k, v) in self.__dict__.items() if k in self.valid_keys and k not in self.internal_keys } d.update(self.unrecognized_keys) @@ -78,7 +79,8 @@ class JsonEncodedObject(object): def get_internal_dict(self): d = { - k: _encode(v, internal=True) for (k, v) in self.__dict__.items() + k: _encode(v, internal=True) + for (k, v) in self.__dict__.items() if k in self.valid_keys } d.update(self.unrecognized_keys) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index fe412355d8..40e5c10a49 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py
@@ -1,4 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,633 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Thread-local-alike tracking of log contexts within synapse - -This module provides objects and utilities for tracking contexts through -synapse code, so that log lines can include a request identifier, and so that -CPU and database activity can be accounted for against the request that caused -them. - -See doc/log_contexts.rst for details on how this works. +""" +Backwards compatibility re-exports of ``synapse.logging.context`` functionality. """ -import logging -import threading - -from twisted.internet import defer, threads - -logger = logging.getLogger(__name__) - -try: - import resource - - # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined - # to be 1 on linux so we hard code it. - RUSAGE_THREAD = 1 - - # If the system doesn't support RUSAGE_THREAD then this should throw an - # exception. - resource.getrusage(RUSAGE_THREAD) - - def get_thread_resource_usage(): - return resource.getrusage(RUSAGE_THREAD) -except Exception: - # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we - # won't track resource usage by returning None. - def get_thread_resource_usage(): - return None - - -class ContextResourceUsage(object): - """Object for tracking the resources used by a log context - - Attributes: - ru_utime (float): user CPU time (in seconds) - ru_stime (float): system CPU time (in seconds) - db_txn_count (int): number of database transactions done - db_sched_duration_sec (float): amount of time spent waiting for a - database connection - db_txn_duration_sec (float): amount of time spent doing database - transactions (excluding scheduling time) - evt_db_fetch_count (int): number of events requested from the database - """ - - __slots__ = [ - "ru_stime", "ru_utime", - "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec", - "evt_db_fetch_count", - ] - - def __init__(self, copy_from=None): - """Create a new ContextResourceUsage - - Args: - copy_from (ContextResourceUsage|None): if not None, an object to - copy stats from - """ - if copy_from is None: - self.reset() - else: - self.ru_utime = copy_from.ru_utime - self.ru_stime = copy_from.ru_stime - self.db_txn_count = copy_from.db_txn_count - - self.db_txn_duration_sec = copy_from.db_txn_duration_sec - self.db_sched_duration_sec = copy_from.db_sched_duration_sec - self.evt_db_fetch_count = copy_from.evt_db_fetch_count - - def copy(self): - return ContextResourceUsage(copy_from=self) - - def reset(self): - self.ru_stime = 0. - self.ru_utime = 0. - self.db_txn_count = 0 - - self.db_txn_duration_sec = 0 - self.db_sched_duration_sec = 0 - self.evt_db_fetch_count = 0 - - def __repr__(self): - return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', " - "db_txn_count='%r', db_txn_duration_sec='%r', " - "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % ( - self.ru_stime, - self.ru_utime, - self.db_txn_count, - self.db_txn_duration_sec, - self.db_sched_duration_sec, - self.evt_db_fetch_count,) - - def __iadd__(self, other): - """Add another ContextResourceUsage's stats to this one's. - - Args: - other (ContextResourceUsage): the other resource usage object - """ - self.ru_utime += other.ru_utime - self.ru_stime += other.ru_stime - self.db_txn_count += other.db_txn_count - self.db_txn_duration_sec += other.db_txn_duration_sec - self.db_sched_duration_sec += other.db_sched_duration_sec - self.evt_db_fetch_count += other.evt_db_fetch_count - return self - - def __isub__(self, other): - self.ru_utime -= other.ru_utime - self.ru_stime -= other.ru_stime - self.db_txn_count -= other.db_txn_count - self.db_txn_duration_sec -= other.db_txn_duration_sec - self.db_sched_duration_sec -= other.db_sched_duration_sec - self.evt_db_fetch_count -= other.evt_db_fetch_count - return self - - def __add__(self, other): - res = ContextResourceUsage(copy_from=self) - res += other - return res - - def __sub__(self, other): - res = ContextResourceUsage(copy_from=self) - res -= other - return res - - -class LoggingContext(object): - """Additional context for log formatting. Contexts are scoped within a - "with" block. - - If a parent is given when creating a new context, then: - - logging fields are copied from the parent to the new context on entry - - when the new context exits, the cpu usage stats are copied from the - child to the parent - - Args: - name (str): Name for the context for debugging. - parent_context (LoggingContext|None): The parent of the new context - """ - - __slots__ = [ - "previous_context", "name", "parent_context", - "_resource_usage", - "usage_start", - "main_thread", "alive", - "request", "tag", - ] - - thread_local = threading.local() - - class Sentinel(object): - """Sentinel to represent the root context""" - - __slots__ = [] - - def __str__(self): - return "sentinel" - - def copy_to(self, record): - pass - - def start(self): - pass - - def stop(self): - pass - - def add_database_transaction(self, duration_sec): - pass - - def add_database_scheduled(self, sched_sec): - pass - - def record_event_fetch(self, event_count): - pass - - def __nonzero__(self): - return False - __bool__ = __nonzero__ # python3 - - sentinel = Sentinel() - - def __init__(self, name=None, parent_context=None, request=None): - self.previous_context = LoggingContext.current_context() - self.name = name - - # track the resources used by this context so far - self._resource_usage = ContextResourceUsage() - - # If alive has the thread resource usage when the logcontext last - # became active. - self.usage_start = None - - self.main_thread = threading.current_thread() - self.request = None - self.tag = "" - self.alive = True - - self.parent_context = parent_context - - if self.parent_context is not None: - self.parent_context.copy_to(self) - - if request is not None: - # the request param overrides the request from the parent context - self.request = request - - def __str__(self): - if self.request: - return str(self.request) - return "%s@%x" % (self.name, id(self)) - - @classmethod - def current_context(cls): - """Get the current logging context from thread local storage - - Returns: - LoggingContext: the current logging context - """ - return getattr(cls.thread_local, "current_context", cls.sentinel) - - @classmethod - def set_current_context(cls, context): - """Set the current logging context in thread local storage - Args: - context(LoggingContext): The context to activate. - Returns: - The context that was previously active - """ - current = cls.current_context() - - if current is not context: - current.stop() - cls.thread_local.current_context = context - context.start() - return current - - def __enter__(self): - """Enters this logging context into thread local storage""" - old_context = self.set_current_context(self) - if self.previous_context != old_context: - logger.warn( - "Expected previous context %r, found %r", - self.previous_context, old_context - ) - self.alive = True - - return self - - def __exit__(self, type, value, traceback): - """Restore the logging context in thread local storage to the state it - was before this context was entered. - Returns: - None to avoid suppressing any exceptions that were thrown. - """ - current = self.set_current_context(self.previous_context) - if current is not self: - if current is self.sentinel: - logger.warning("Expected logging context %s was lost", self) - else: - logger.warning( - "Expected logging context %s but found %s", self, current - ) - self.previous_context = None - self.alive = False - - # if we have a parent, pass our CPU usage stats on - if ( - self.parent_context is not None - and hasattr(self.parent_context, '_resource_usage') - ): - self.parent_context._resource_usage += self._resource_usage - - # reset them in case we get entered again - self._resource_usage.reset() - - def copy_to(self, record): - """Copy logging fields from this context to a log record or - another LoggingContext - """ - - # 'request' is the only field we currently use in the logger, so that's - # all we need to copy - record.request = self.request - - def start(self): - if threading.current_thread() is not self.main_thread: - logger.warning("Started logcontext %s on different thread", self) - return - - # If we haven't already started record the thread resource usage so - # far - if not self.usage_start: - self.usage_start = get_thread_resource_usage() - - def stop(self): - if threading.current_thread() is not self.main_thread: - logger.warning("Stopped logcontext %s on different thread", self) - return - - # When we stop, let's record the cpu used since we started - if not self.usage_start: - logger.warning( - "Called stop on logcontext %s without calling start", self, - ) - return - - usage_end = get_thread_resource_usage() - - self._resource_usage.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime - self._resource_usage.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime - - self.usage_start = None - - def get_resource_usage(self): - """Get resources used by this logcontext so far. - - Returns: - ContextResourceUsage: a *copy* of the object tracking resource - usage so far - """ - # we always return a copy, for consistency - res = self._resource_usage.copy() - - # If we are on the correct thread and we're currently running then we - # can include resource usage so far. - is_main_thread = threading.current_thread() is self.main_thread - if self.alive and self.usage_start and is_main_thread: - current = get_thread_resource_usage() - res.ru_utime += current.ru_utime - self.usage_start.ru_utime - res.ru_stime += current.ru_stime - self.usage_start.ru_stime - - return res - - def add_database_transaction(self, duration_sec): - self._resource_usage.db_txn_count += 1 - self._resource_usage.db_txn_duration_sec += duration_sec - - def add_database_scheduled(self, sched_sec): - """Record a use of the database pool - - Args: - sched_sec (float): number of seconds it took us to get a - connection - """ - self._resource_usage.db_sched_duration_sec += sched_sec - - def record_event_fetch(self, event_count): - """Record a number of events being fetched from the db - - Args: - event_count (int): number of events being fetched - """ - self._resource_usage.evt_db_fetch_count += event_count - - -class LoggingContextFilter(logging.Filter): - """Logging filter that adds values from the current logging context to each - record. - Args: - **defaults: Default values to avoid formatters complaining about - missing fields - """ - def __init__(self, **defaults): - self.defaults = defaults - - def filter(self, record): - """Add each fields from the logging contexts to the record. - Returns: - True to include the record in the log output. - """ - context = LoggingContext.current_context() - for key, value in self.defaults.items(): - setattr(record, key, value) - - # context should never be None, but if it somehow ends up being, then - # we end up in a death spiral of infinite loops, so let's check, for - # robustness' sake. - if context is not None: - context.copy_to(record) - - return True - - -class PreserveLoggingContext(object): - """Captures the current logging context and restores it when the scope is - exited. Used to restore the context after a function using - @defer.inlineCallbacks is resumed by a callback from the reactor.""" - - __slots__ = ["current_context", "new_context", "has_parent"] - - def __init__(self, new_context=None): - if new_context is None: - new_context = LoggingContext.sentinel - self.new_context = new_context - - def __enter__(self): - """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context( - self.new_context - ) - - if self.current_context: - self.has_parent = self.current_context.previous_context is not None - if not self.current_context.alive: - logger.debug( - "Entering dead context: %s", - self.current_context, - ) - - def __exit__(self, type, value, traceback): - """Restores the current logging context""" - context = LoggingContext.set_current_context(self.current_context) - - if context != self.new_context: - if context is LoggingContext.sentinel: - logger.warning("Expected logging context %s was lost", self.new_context) - else: - logger.warning( - "Expected logging context %s but found %s", - self.new_context, - context, - ) - - if self.current_context is not LoggingContext.sentinel: - if not self.current_context.alive: - logger.debug( - "Restoring dead context: %s", - self.current_context, - ) - - -def nested_logging_context(suffix, parent_context=None): - """Creates a new logging context as a child of another. - - The nested logging context will have a 'request' made up of the parent context's - request, plus the given suffix. - - CPU/db usage stats will be added to the parent context's on exit. - - Normal usage looks like: - - with nested_logging_context(suffix): - # ... do stuff - - Args: - suffix (str): suffix to add to the parent context's 'request'. - parent_context (LoggingContext|None): parent context. Will use the current context - if None. - - Returns: - LoggingContext: new logging context. - """ - if parent_context is None: - parent_context = LoggingContext.current_context() - return LoggingContext( - parent_context=parent_context, - request=parent_context.request + "-" + suffix, - ) - - -def preserve_fn(f): - """Function decorator which wraps the function with run_in_background""" - def g(*args, **kwargs): - return run_in_background(f, *args, **kwargs) - return g - - -def run_in_background(f, *args, **kwargs): - """Calls a function, ensuring that the current context is restored after - return from the function, and that the sentinel context is set once the - deferred returned by the function completes. - - Useful for wrapping functions that return a deferred which you don't yield - on (for instance because you want to pass it to deferred.gatherResults()). - - Note that if you completely discard the result, you should make sure that - `f` doesn't raise any deferred exceptions, otherwise a scary-looking - CRITICAL error about an unhandled error will be logged without much - indication about where it came from. - """ - current = LoggingContext.current_context() - try: - res = f(*args, **kwargs) - except: # noqa: E722 - # the assumption here is that the caller doesn't want to be disturbed - # by synchronous exceptions, so let's turn them into Failures. - return defer.fail() - - if not isinstance(res, defer.Deferred): - return res - - if res.called and not res.paused: - # The function should have maintained the logcontext, so we can - # optimise out the messing about - return res - - # The function may have reset the context before returning, so - # we need to restore it now. - ctx = LoggingContext.set_current_context(current) - - # The original context will be restored when the deferred - # completes, but there is nothing waiting for it, so it will - # get leaked into the reactor or some other function which - # wasn't expecting it. We therefore need to reset the context - # here. - # - # (If this feels asymmetric, consider it this way: we are - # effectively forking a new thread of execution. We are - # probably currently within a ``with LoggingContext()`` block, - # which is supposed to have a single entry and exit point. But - # by spawning off another deferred, we are effectively - # adding a new exit point.) - res.addBoth(_set_context_cb, ctx) - return res - - -def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: - - If the deferred has completed (or is not actually a Deferred), essentially - does nothing (just returns another completed deferred with the - result/failure). - - If the deferred has not yet completed, resets the logcontext before - returning a deferred. Then, when the deferred completes, restores the - current logcontext before running callbacks/errbacks. - - (This is more-or-less the opposite operation to run_in_background.) - """ - if not isinstance(deferred, defer.Deferred): - return deferred - - if deferred.called and not deferred.paused: - # it looks like this deferred is ready to run any callbacks we give it - # immediately. We may as well optimise out the logcontext faffery. - return deferred - - # ok, we can't be sure that a yield won't block, so let's reset the - # logcontext, and add a callback to the deferred to restore it. - prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) - deferred.addBoth(_set_context_cb, prev_context) - return deferred - - -def _set_context_cb(result, context): - """A callback function which just sets the logging context""" - LoggingContext.set_current_context(context) - return result - - -def defer_to_thread(reactor, f, *args, **kwargs): - """ - Calls the function `f` using a thread from the reactor's default threadpool and - returns the result as a Deferred. - - Creates a new logcontext for `f`, which is created as a child of the current - logcontext (so its CPU usage metrics will get attributed to the current - logcontext). `f` should preserve the logcontext it is given. - - The result deferred follows the Synapse logcontext rules: you should `yield` - on it. - - Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked, and whose threadpool we should use for the - function. - - Normally this will be hs.get_reactor(). - - f (callable): The function to call. - - args: positional arguments to pass to f. - - kwargs: keyword arguments to pass to f. - - Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an - errback if `f` throws an exception. - """ - return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) - - -def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): - """ - A wrapper for twisted.internet.threads.deferToThreadpool, which handles - logcontexts correctly. - - Calls the function `f` using a thread from the given threadpool and returns - the result as a Deferred. - - Creates a new logcontext for `f`, which is created as a child of the current - logcontext (so its CPU usage metrics will get attributed to the current - logcontext). `f` should preserve the logcontext it is given. - - The result deferred follows the Synapse logcontext rules: you should `yield` - on it. - - Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked. Normally this will be hs.get_reactor(). - - threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for - running `f`. Normally this will be hs.get_reactor().getThreadPool(). - - f (callable): The function to call. - - args: positional arguments to pass to f. - - kwargs: keyword arguments to pass to f. - - Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an - errback if `f` throws an exception. - """ - logcontext = LoggingContext.current_context() - - def g(): - with LoggingContext(parent_context=logcontext): - return f(*args, **kwargs) - - return make_deferred_yieldable( - threads.deferToThreadPool(reactor, threadpool, g) - ) +from synapse.logging.context import ( + LoggingContext, + LoggingContextFilter, + PreserveLoggingContext, + defer_to_thread, + make_deferred_yieldable, + nested_logging_context, + preserve_fn, + run_in_background, +) + +__all__ = [ + "defer_to_thread", + "LoggingContext", + "LoggingContextFilter", + "make_deferred_yieldable", + "nested_logging_context", + "preserve_fn", + "PreserveLoggingContext", + "run_in_background", +] diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index a46bc47ce3..320e8f8174 100644 --- a/synapse/util/logformatter.py +++ b/synapse/util/logformatter.py
@@ -1,5 +1,4 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,40 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Backwards compatibility re-exports of ``synapse.logging.formatter`` functionality. +""" -import logging -import traceback +from synapse.logging.formatter import LogFormatter -from six import StringIO - - -class LogFormatter(logging.Formatter): - """Log formatter which gives more detail for exceptions - - This is the same as the standard log formatter, except that when logging - exceptions [typically via log.foo("msg", exc_info=1)], it prints the - sequence that led up to the point at which the exception was caught. - (Normally only stack frames between the point the exception was raised and - where it was caught are logged). - """ - def __init__(self, *args, **kwargs): - super(LogFormatter, self).__init__(*args, **kwargs) - - def formatException(self, ei): - sio = StringIO() - (typ, val, tb) = ei - - # log the stack above the exception capture point if possible, but - # check that we actually have an f_back attribute to work around - # https://twistedmatrix.com/trac/ticket/9305 - - if tb and hasattr(tb.tb_frame, 'f_back'): - sio.write("Capture point (most recent call last):\n") - traceback.print_stack(tb.tb_frame.f_back, None, sio) - - traceback.print_exception(typ, val, tb, None, sio) - s = sio.getvalue() - sio.close() - if s[-1:] == "\n": - s = s[:-1] - return s +__all__ = ["LogFormatter"] diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py deleted file mode 100644
index ef31458226..0000000000 --- a/synapse/util/logutils.py +++ /dev/null
@@ -1,209 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import inspect -import logging -import time -from functools import wraps -from inspect import getcallargs - -from six import PY3 - -_TIME_FUNC_ID = 0 - - -def _log_debug_as_f(f, msg, msg_args): - name = f.__module__ - logger = logging.getLogger(name) - - if logger.isEnabledFor(logging.DEBUG): - if PY3: - lineno = f.__code__.co_firstlineno - pathname = f.__code__.co_filename - else: - lineno = f.func_code.co_firstlineno - pathname = f.func_code.co_filename - - record = logging.LogRecord( - name=name, - level=logging.DEBUG, - pathname=pathname, - lineno=lineno, - msg=msg, - args=msg_args, - exc_info=None - ) - - logger.handle(record) - - -def log_function(f): - """ Function decorator that logs every call to that function. - """ - func_name = f.__name__ - - @wraps(f) - def wrapped(*args, **kwargs): - name = f.__module__ - logger = logging.getLogger(name) - level = logging.DEBUG - - if logger.isEnabledFor(level): - bound_args = getcallargs(f, *args, **kwargs) - - def format(value): - r = str(value) - if len(r) > 50: - r = r[:50] + "..." - return r - - func_args = [ - "%s=%s" % (k, format(v)) for k, v in bound_args.items() - ] - - msg_args = { - "func_name": func_name, - "args": ", ".join(func_args) - } - - _log_debug_as_f( - f, - "Invoked '%(func_name)s' with args: %(args)s", - msg_args - ) - - return f(*args, **kwargs) - - wrapped.__name__ = func_name - return wrapped - - -def time_function(f): - func_name = f.__name__ - - @wraps(f) - def wrapped(*args, **kwargs): - global _TIME_FUNC_ID - id = _TIME_FUNC_ID - _TIME_FUNC_ID += 1 - - start = time.clock() - - try: - _log_debug_as_f( - f, - "[FUNC START] {%s-%d}", - (func_name, id), - ) - - r = f(*args, **kwargs) - finally: - end = time.clock() - _log_debug_as_f( - f, - "[FUNC END] {%s-%d} %.3f sec", - (func_name, id, end - start,), - ) - - return r - - return wrapped - - -def trace_function(f): - func_name = f.__name__ - linenum = f.func_code.co_firstlineno - pathname = f.func_code.co_filename - - @wraps(f) - def wrapped(*args, **kwargs): - name = f.__module__ - logger = logging.getLogger(name) - level = logging.DEBUG - - s = inspect.currentframe().f_back - - to_print = [ - "\t%s:%s %s. Args: args=%s, kwargs=%s" % ( - pathname, linenum, func_name, args, kwargs - ) - ] - while s: - if True or s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_print.append( - "\t%s:%d %s. Args: %s" % ( - filename, lineno, function, args_string - ) - ) - - s = s.f_back - - msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print) - - record = logging.LogRecord( - name=name, - level=level, - pathname=pathname, - lineno=lineno, - msg=msg, - args=None, - exc_info=None - ) - - logger.handle(record) - - return f(*args, **kwargs) - - wrapped.__name__ = func_name - return wrapped - - -def get_previous_frames(): - s = inspect.currentframe().f_back.f_back - to_return = [] - while s: - if s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_return.append("{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string - )) - - s = s.f_back - - return ", ". join(to_return) - - -def get_previous_frame(ignore=[]): - s = inspect.currentframe().f_back.f_back - - while s: - if s.f_globals["__name__"].startswith("synapse"): - if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - return "{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string - ) - - s = s.f_back - - return None diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 628a2962d9..631654f297 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py
@@ -74,27 +74,25 @@ def manhole(username, password, globals): twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` """ if not isinstance(password, bytes): - password = password.encode('ascii') + password = password.encode("ascii") - checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( - **{username: password} - ) + checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) rlm = manhole_ssh.TerminalRealm() rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( - SynapseManhole, - dict(globals, __name__="__console__") + SynapseManhole, dict(globals, __name__="__console__") ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY) - factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY) + factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY) + factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY) return factory class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" + def connectionMade(self): super(SynapseManhole, self).connectionMade() @@ -127,7 +125,7 @@ class SynapseManholeInterpreter(ManholeInterpreter): value = SyntaxError(msg, (filename, lineno, offset, line)) sys.last_value = value lines = traceback.format_exception_only(type, value) - self.write(''.join(lines)) + self.write("".join(lines)) def showtraceback(self): """Display the exception that just occurred. @@ -140,6 +138,6 @@ class SynapseManholeInterpreter(ManholeInterpreter): try: # We remove the first stack item because it is our own code. lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) - self.write(''.join(lines)) + self.write("".join(lines)) finally: last_tb = ei = None diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4b4ac5f6c7..7b18455469 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py
@@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging from functools import wraps @@ -20,8 +21,8 @@ from prometheus_client import Counter from twisted.internet import defer +from synapse.logging.context import LoggingContext from synapse.metrics import InFlightGauge -from synapse.util.logcontext import LoggingContext logger = logging.getLogger(__name__) @@ -30,108 +31,108 @@ block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"]) block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"]) block_ru_utime = Counter( - "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"] +) block_ru_stime = Counter( - "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"] +) block_db_txn_count = Counter( - "synapse_util_metrics_block_db_txn_count", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_count", "", ["block_name"] +) # seconds spent waiting for db txns, excluding scheduling time, in this block block_db_txn_duration = Counter( - "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"] +) # seconds spent waiting for a db connection, in this block block_db_sched_duration = Counter( - "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"] +) # Tracks the number of blocks currently active in_flight = InFlightGauge( - "synapse_util_metrics_block_in_flight", "", + "synapse_util_metrics_block_in_flight", + "", labels=["block_name"], sub_metrics=["real_time_max", "real_time_sum"], ) -def measure_func(name): +def measure_func(name=None): def wrapper(func): - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, name): - r = yield func(self, *args, **kwargs) - defer.returnValue(r) + block_name = func.__name__ if name is None else name + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r + + else: + + @wraps(func) + @defer.inlineCallbacks + def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = yield func(self, *args, **kwargs) + return r + return measured_func + return wrapper class Measure(object): __slots__ = [ - "clock", "name", "start_context", "start", - "created_context", - "start_usage", + "clock", + "name", + "_logging_context", + "start", ] def __init__(self, clock, name): self.clock = clock self.name = name - self.start_context = None + self._logging_context = None self.start = None - self.created_context = False def __enter__(self): - self.start = self.clock.time() - self.start_context = LoggingContext.current_context() - if not self.start_context: - self.start_context = LoggingContext("Measure") - self.start_context.__enter__() - self.created_context = True - - self.start_usage = self.start_context.get_resource_usage() + if self._logging_context: + raise RuntimeError("Measure() objects cannot be re-used") + self.start = self.clock.time() + parent_context = LoggingContext.current_context() + self._logging_context = LoggingContext( + "Measure[%s]" % (self.name,), parent_context + ) + self._logging_context.__enter__() in_flight.register((self.name,), self._update_in_flight) def __exit__(self, exc_type, exc_val, exc_tb): - if isinstance(exc_type, Exception) or not self.start_context: - return - - in_flight.unregister((self.name,), self._update_in_flight) + if not self._logging_context: + raise RuntimeError("Measure() block exited without being entered") duration = self.clock.time() - self.start + usage = self._logging_context.get_resource_usage() - block_counter.labels(self.name).inc() - block_timer.labels(self.name).inc(duration) - - context = LoggingContext.current_context() - - if context != self.start_context: - logger.warn( - "Context has unexpectedly changed from '%s' to '%s'. (%r)", - self.start_context, context, self.name - ) - return - - if not context: - logger.warn("Expected context. (%r)", self.name) - return + in_flight.unregister((self.name,), self._update_in_flight) + self._logging_context.__exit__(exc_type, exc_val, exc_tb) - current = context.get_resource_usage() - usage = current - self.start_usage try: + block_counter.labels(self.name).inc() + block_timer.labels(self.name).inc(duration) block_ru_utime.labels(self.name).inc(usage.ru_utime) block_ru_stime.labels(self.name).inc(usage.ru_stime) block_db_txn_count.labels(self.name).inc(usage.db_txn_count) block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: - logger.warn( - "Failed to save metrics! OLD: %r, NEW: %r", - self.start_usage, current - ) - - if self.created_context: - self.start_context.__exit__(exc_type, exc_val, exc_tb) + logger.warning("Failed to save metrics! Usage: %s", usage) def _update_in_flight(self, metrics): """Gets called when processing in flight metrics diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 4288312b8a..bb62db4637 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py
@@ -14,12 +14,13 @@ # limitations under the License. import importlib +import importlib.util from synapse.config._base import ConfigError def load_module(provider): - """ Loads a module with its config + """ Loads a synapse module with its config Take a dict with keys 'module' (the module name) and 'config' (the config dict). @@ -28,15 +29,30 @@ def load_module(provider): """ # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) + module, clz = provider["module"].rsplit(".", 1) module = importlib.import_module(module) provider_class = getattr(module, clz) try: - provider_config = provider_class.parse_config(provider["config"]) + provider_config = provider_class.parse_config(provider.get("config")) except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) + raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e)) return provider_class, provider_config + + +def load_python_module(location: str): + """Load a python module, and return a reference to its global namespace + + Args: + location (str): path to the module + + Returns: + python module object + """ + spec = importlib.util.spec_from_file_location(location, location) + if spec is None: + raise Exception("Unable to load module at %s" % (location,)) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) # type: ignore + return mod diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
index a6c30e5265..c8bcbe297a 100644 --- a/synapse/util/msisdn.py +++ b/synapse/util/msisdn.py
@@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number): phoneNumber = phonenumbers.parse(number, country) except phonenumbers.NumberParseException: raise SynapseError(400, "Unable to parse phone number") - return phonenumbers.format_number( - phoneNumber, phonenumbers.PhoneNumberFormat.E164 - )[1:] + return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[ + 1: + ] diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py new file mode 100644
index 0000000000..3925927f9f --- /dev/null +++ b/synapse/util/patch_inline_callbacks.py
@@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import functools +import sys +from typing import Any, Callable, List + +from twisted.internet import defer +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure + +# Tracks if we've already patched inlineCallbacks +_already_patched = False + + +def do_patch(): + """ + Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit + """ + + from synapse.logging.context import LoggingContext + + global _already_patched + + orig_inline_callbacks = defer.inlineCallbacks + if _already_patched: + return + + def new_inline_callbacks(f): + @functools.wraps(f) + def wrapped(*args, **kwargs): + start_context = LoggingContext.current_context() + changes = [] # type: List[str] + orig = orig_inline_callbacks(_check_yield_points(f, changes)) + + try: + res = orig(*args, **kwargs) + except Exception: + if LoggingContext.current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + + err = "%s changed context from %s to %s on exception" % ( + f, + start_context, + LoggingContext.current_context(), + ) + print(err, file=sys.stderr) + raise Exception(err) + raise + + if not isinstance(res, Deferred) or res.called: + if LoggingContext.current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + + err = "Completed %s changed context from %s to %s" % ( + f, + start_context, + LoggingContext.current_context(), + ) + # print the error to stderr because otherwise all we + # see in travis-ci is the 500 error + print(err, file=sys.stderr) + raise Exception(err) + return res + + if LoggingContext.current_context() != LoggingContext.sentinel: + err = ( + "%s returned incomplete deferred in non-sentinel context " + "%s (start was %s)" + ) % (f, LoggingContext.current_context(), start_context) + print(err, file=sys.stderr) + raise Exception(err) + + def check_ctx(r): + if LoggingContext.current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + err = "%s completion of %s changed context from %s to %s" % ( + "Failure" if isinstance(r, Failure) else "Success", + f, + start_context, + LoggingContext.current_context(), + ) + print(err, file=sys.stderr) + raise Exception(err) + return r + + res.addBoth(check_ctx) + return res + + return wrapped + + defer.inlineCallbacks = new_inline_callbacks + _already_patched = True + + +def _check_yield_points(f: Callable, changes: List[str]): + """Wraps a generator that is about to be passed to defer.inlineCallbacks + checking that after every yield the log contexts are correct. + + It's perfectly valid for log contexts to change within a function, e.g. due + to new Measure blocks, so such changes are added to the given `changes` + list instead of triggering an exception. + + Args: + f: generator function to wrap + changes: A list of strings detailing how the contexts + changed within a function. + + Returns: + function + """ + + from synapse.logging.context import LoggingContext + + @functools.wraps(f) + def check_yield_points_inner(*args, **kwargs): + gen = f(*args, **kwargs) + + last_yield_line_no = gen.gi_frame.f_lineno + result = None # type: Any + while True: + expected_context = LoggingContext.current_context() + + try: + isFailure = isinstance(result, Failure) + if isFailure: + d = result.throwExceptionIntoGenerator(gen) + else: + d = gen.send(result) + except (StopIteration, defer._DefGen_Return) as e: + if LoggingContext.current_context() != expected_context: + # This happens when the context is lost sometime *after* the + # final yield and returning. E.g. we forgot to yield on a + # function that returns a deferred. + # + # We don't raise here as it's perfectly valid for contexts to + # change in a function, as long as it sets the correct context + # on resolving (which is checked separately). + err = ( + "Function %r returned and changed context from %s to %s," + " in %s between %d and end of func" + % ( + f.__qualname__, + expected_context, + LoggingContext.current_context(), + f.__code__.co_filename, + last_yield_line_no, + ) + ) + changes.append(err) + return getattr(e, "value", None) + + frame = gen.gi_frame + + if isinstance(d, defer.Deferred) and not d.called: + # This happens if we yield on a deferred that doesn't follow + # the log context rules without wrapping in a `make_deferred_yieldable`. + # We raise here as this should never happen. + if LoggingContext.current_context() is not LoggingContext.sentinel: + err = ( + "%s yielded with context %s rather than sentinel," + " yielded on line %d in %s" + % ( + frame.f_code.co_name, + LoggingContext.current_context(), + frame.f_lineno, + frame.f_code.co_filename, + ) + ) + raise Exception(err) + + try: + result = yield d + except Exception as e: + result = Failure(e) + + if LoggingContext.current_context() != expected_context: + + # This happens because the context is lost sometime *after* the + # previous yield and *after* the current yield. E.g. the + # deferred we waited on didn't follow the rules, or we forgot to + # yield on a function between the two yield points. + # + # We don't raise here as its perfectly valid for contexts to + # change in a function, as long as it sets the correct context + # on resolving (which is checked separately). + err = ( + "%s changed context from %s to %s, happened between lines %d and %d in %s" + % ( + frame.f_code.co_name, + expected_context, + LoggingContext.current_context(), + last_yield_line_no, + frame.f_lineno, + frame.f_code.co_filename, + ) + ) + changes.append(err) + + last_yield_line_no = frame.f_lineno + + return check_yield_points_inner diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index b146d137f4..5ca4521ce3 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py
@@ -20,7 +20,7 @@ import logging from twisted.internet import defer from synapse.api.errors import LimitExceededError -from synapse.util.logcontext import ( +from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background, @@ -36,9 +36,11 @@ class FederationRateLimiter(object): clock (Clock) config (FederationRateLimitConfig) """ - self.clock = clock - self._config = config - self.ratelimiters = {} + + def new_limiter(): + return _PerHostRatelimiter(clock=clock, config=config) + + self.ratelimiters = collections.defaultdict(new_limiter) def ratelimit(self, host): """Used to ratelimit an incoming request from given host @@ -53,15 +55,9 @@ class FederationRateLimiter(object): host (str): Origin of incoming request. Returns: - _PerHostRatelimiter + context manager which returns a deferred. """ - return self.ratelimiters.setdefault( - host, - _PerHostRatelimiter( - clock=self.clock, - config=self._config, - ) - ).ratelimit() + return self.ratelimiters[host].ratelimit() class _PerHostRatelimiter(object): @@ -112,8 +108,7 @@ class _PerHostRatelimiter(object): # remove any entries from request_times which aren't within the window self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size + r for r in self.request_times if time_now - r < self.window_size ] # reject the request if we already have too many queued up (either @@ -121,15 +116,13 @@ class _PerHostRatelimiter(object): queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) if queue_size > self.reject_limit: raise LimitExceededError( - retry_after_ms=int( - self.window_size / self.sleep_limit - ), + retry_after_ms=int(self.window_size / self.sleep_limit) ) self.request_times.append(time_now) def queue_request(): - if len(self.current_processing) > self.concurrent_requests: + if len(self.current_processing) >= self.concurrent_requests: queue_defer = defer.Deferred() self.ready_request_queue[request_id] = queue_defer logger.info( @@ -143,22 +136,18 @@ class _PerHostRatelimiter(object): logger.debug( "Ratelimit [%s]: len(self.request_times)=%d", - id(request_id), len(self.request_times), + id(request_id), + len(self.request_times), ) if len(self.request_times) > self.sleep_limit: - logger.debug( - "Ratelimiter: sleeping request for %f sec", self.sleep_sec, - ) + logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec) ret_defer = run_in_background(self.clock.sleep, self.sleep_sec) self.sleeping_requests.add(request_id) def on_wait_finished(_): - logger.debug( - "Ratelimit [%s]: Finished sleeping", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) self.sleeping_requests.discard(request_id) queue_defer = queue_request() return queue_defer @@ -168,10 +157,7 @@ class _PerHostRatelimiter(object): ret_defer = queue_request() def on_start(r): - logger.debug( - "Ratelimit [%s]: Processing req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processing req", id(request_id)) self.current_processing.add(request_id) return r @@ -193,10 +179,7 @@ class _PerHostRatelimiter(object): return make_deferred_yieldable(ret_defer) def _on_exit(self, request_id): - logger.debug( - "Ratelimit [%s]: Processed req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processed req", id(request_id)) self.current_processing.discard(request_id) try: # start processing the next item on the queue. diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 1a77456498..af69587196 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py
@@ -17,11 +17,20 @@ import random from twisted.internet import defer -import synapse.util.logcontext +import synapse.logging.context from synapse.api.errors import CodeMessageException logger = logging.getLogger(__name__) +# the intial backoff, after the first transaction fails +MIN_RETRY_INTERVAL = 10 * 60 * 1000 + +# how much we multiply the backoff by after each subsequent fail +RETRY_MULTIPLIER = 5 + +# a cap on the backoff. (Essentially none) +MAX_RETRY_INTERVAL = 2 ** 62 + class NotRetryingDestination(Exception): def __init__(self, retry_last_ts, retry_interval, destination): @@ -71,11 +80,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) # We aren't ready to retry that destination. raise """ + failure_ts = None retry_last_ts, retry_interval = (0, 0) retry_timings = yield store.get_destination_retry_timings(destination) if retry_timings: + failure_ts = retry_timings["failure_ts"] retry_last_ts, retry_interval = ( retry_timings["retry_last_ts"], retry_timings["retry_interval"], @@ -95,15 +106,14 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) # maximum backoff even though it might only have been down briefly backoff_on_failure = not ignore_backoff - defer.returnValue( - RetryDestinationLimiter( - destination, - clock, - store, - retry_interval, - backoff_on_failure=backoff_on_failure, - **kwargs - ) + return RetryDestinationLimiter( + destination, + clock, + store, + failure_ts, + retry_interval, + backoff_on_failure=backoff_on_failure, + **kwargs ) @@ -113,10 +123,8 @@ class RetryDestinationLimiter(object): destination, clock, store, + failure_ts, retry_interval, - min_retry_interval=10 * 60 * 1000, - max_retry_interval=24 * 60 * 60 * 1000, - multiplier_retry_interval=5, backoff_on_404=False, backoff_on_failure=True, ): @@ -129,15 +137,11 @@ class RetryDestinationLimiter(object): destination (str) clock (Clock) store (DataStore) + failure_ts (int|None): when this destination started failing (in ms since + the epoch), or zero if the last request was successful retry_interval (int): The next retry interval taken from the database in milliseconds, or zero if the last request was successful. - min_retry_interval (int): The minimum retry interval to use after - a failed request, in milliseconds. - max_retry_interval (int): The maximum retry interval to use after - a failed request, in milliseconds. - multiplier_retry_interval (int): The multiplier to use to increase - the retry interval after a failed request. backoff_on_404 (bool): Back off if we get a 404 backoff_on_failure (bool): set to False if we should not increase the @@ -147,10 +151,8 @@ class RetryDestinationLimiter(object): self.store = store self.destination = destination + self.failure_ts = failure_ts self.retry_interval = retry_interval - self.min_retry_interval = min_retry_interval - self.max_retry_interval = max_retry_interval - self.multiplier_retry_interval = multiplier_retry_interval self.backoff_on_404 = backoff_on_404 self.backoff_on_failure = backoff_on_failure @@ -191,6 +193,7 @@ class RetryDestinationLimiter(object): logger.debug( "Connection to %s was successful; clearing backoff", self.destination ) + self.failure_ts = None retry_last_ts = 0 self.retry_interval = 0 elif not self.backoff_on_failure: @@ -198,13 +201,14 @@ class RetryDestinationLimiter(object): else: # We couldn't connect. if self.retry_interval: - self.retry_interval *= self.multiplier_retry_interval - self.retry_interval *= int(random.uniform(0.8, 1.4)) + self.retry_interval = int( + self.retry_interval * RETRY_MULTIPLIER * random.uniform(0.8, 1.4) + ) - if self.retry_interval >= self.max_retry_interval: - self.retry_interval = self.max_retry_interval + if self.retry_interval >= MAX_RETRY_INTERVAL: + self.retry_interval = MAX_RETRY_INTERVAL else: - self.retry_interval = self.min_retry_interval + self.retry_interval = MIN_RETRY_INTERVAL logger.info( "Connection to %s was unsuccessful (%s(%s)); backoff now %i", @@ -215,14 +219,20 @@ class RetryDestinationLimiter(object): ) retry_last_ts = int(self.clock.time_msec()) + if self.failure_ts is None: + self.failure_ts = retry_last_ts + @defer.inlineCallbacks def store_retry_timings(): try: yield self.store.set_destination_retry_timings( - self.destination, retry_last_ts, self.retry_interval + self.destination, + self.failure_ts, + retry_last_ts, + self.retry_interval, ) except Exception: logger.exception("Failed to store destination_retry_timings") # we deliberately do this in the background. - synapse.util.logcontext.run_in_background(store_retry_timings) + synapse.logging.context.run_in_background(store_retry_timings) diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
index 6c0f2bb0cf..207cd17c2a 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py
@@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no): resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) ) except (ValueError, resource.error) as e: - logger.warn("Failed to set file or core limit: %s", e) + logger.warning("Failed to set file or core limit: %s", e) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 5fb18ee1f8..2c0dcb5208 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py
@@ -24,26 +24,25 @@ from six.moves import range from synapse.api.errors import Codes, SynapseError -_string_with_symbols = ( - string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" -) +_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" + +# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken +# Note: The : character is allowed here for older clients, but will be removed in a +# future release. Context: https://github.com/matrix-org/synapse/issues/6766 +client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$") # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. rand = random.SystemRandom() -client_secret_regex = re.compile(r"^[0-9a-zA-Z.=_-]+$") - def random_string(length): - return ''.join(rand.choice(string.ascii_letters) for _ in range(length)) + return "".join(rand.choice(string.ascii_letters) for _ in range(length)) def random_string_with_symbols(length): - return ''.join( - rand.choice(_string_with_symbols) for _ in range(length) - ) + return "".join(rand.choice(_string_with_symbols) for _ in range(length)) def is_ascii(s): @@ -51,7 +50,7 @@ def is_ascii(s): if PY3: if isinstance(s, bytes): try: - s.decode('ascii').encode('ascii') + s.decode("ascii").encode("ascii") except UnicodeDecodeError: return False except UnicodeEncodeError: @@ -110,13 +109,13 @@ def exception_to_unicode(e): # and instead look at what is in the args member. if len(e.args) == 0: - return u"" + return "" elif len(e.args) > 1: return six.text_type(repr(e.args)) msg = e.args[0] if isinstance(msg, bytes): - return msg.decode('utf-8', errors='replace') + return msg.decode("utf-8", errors="replace") else: return msg diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 4cc7d27ce5..34ce7cac16 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py
@@ -36,25 +36,26 @@ def check_3pid_allowed(hs, medium, address): if hs.config.check_is_for_allowed_local_3pids: data = yield hs.get_simple_http_client().get_json( - "https://%s%s" % ( + "https://%s%s" + % ( hs.config.check_is_for_allowed_local_3pids, - "/_matrix/identity/api/v1/internal-info" + "/_matrix/identity/api/v1/internal-info", ), - {'medium': medium, 'address': address} + {"medium": medium, "address": address}, ) # Check for invalid response - if 'hs' not in data and 'shadow_hs' not in data: + if "hs" not in data and "shadow_hs" not in data: defer.returnValue(False) # Check if this user is intended to register for this homeserver if ( - data.get('hs') != hs.config.server_name - and data.get('shadow_hs') != hs.config.server_name + data.get("hs") != hs.config.server_name + and data.get("shadow_hs") != hs.config.server_name ): defer.returnValue(False) - if data.get('requires_invite', False) and not data.get('invited', False): + if data.get("requires_invite", False) and not data.get("invited", False): # Requires an invite but hasn't been invited defer.returnValue(False) @@ -64,11 +65,13 @@ def check_3pid_allowed(hs, medium, address): for constraint in hs.config.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", - address, medium, constraint['pattern'], constraint['medium'], + address, + medium, + constraint["pattern"], + constraint["medium"], ) - if ( - medium == constraint['medium'] and - re.match(constraint['pattern'], address) + if medium == constraint["medium"] and re.match( + constraint["pattern"], address ): defer.returnValue(True) else: diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 3baba3225a..ab7d03af3a 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py
@@ -22,63 +22,87 @@ logger = logging.getLogger(__name__) def get_version_string(module): + """Given a module calculate a git-aware version string for it. + + If called on a module not in a git checkout will return `__verison__`. + + Args: + module (module) + + Returns: + str + """ + + cached_version = getattr(module, "_synapse_version_string_cache", None) + if cached_version: + return cached_version + + version_string = module.__version__ + try: - null = open(os.devnull, 'w') + null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) + try: - git_branch = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_branch = ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_branch = "b=" + git_branch - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError can arise when git is not installed git_branch = "" try: - git_tag = subprocess.check_output( - ['git', 'describe', '--exact-match'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_tag = ( + subprocess.check_output( + ["git", "describe", "--exact-match"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_tag = "t=" + git_tag - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_tag = "" try: - git_commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') - except subprocess.CalledProcessError: + git_commit = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) + except (subprocess.CalledProcessError, FileNotFoundError): git_commit = "" try: dirty_string = "-this_is_a_dirty_checkout" - is_dirty = subprocess.check_output( - ['git', 'describe', '--dirty=' + dirty_string], - stderr=null, - cwd=cwd, - ).strip().decode('ascii').endswith(dirty_string) + is_dirty = ( + subprocess.check_output( + ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + .endswith(dirty_string) + ) git_dirty = "dirty" if is_dirty else "" - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_dirty = "" if git_branch or git_tag or git_commit or git_dirty: git_version = ",".join( - s for s in - (git_branch, git_tag, git_commit, git_dirty,) - if s + s for s in (git_branch, git_tag, git_commit, git_dirty) if s ) - return ( - "%s (%s)" % ( - module.__version__, git_version, - ) - ) + version_string = "%s (%s)" % (module.__version__, git_version) except Exception as e: logger.info("Failed to check for git repository: %s", e) - return module.__version__ + module._synapse_version_string_cache = version_string + + return version_string diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 7a9e45aca9..9bf6a44f75 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py
@@ -69,9 +69,7 @@ class WheelTimer(object): # Add empty entries between the end of the current list and when we want # to insert. This ensures there are no gaps. - self.entries.extend( - _Entry(key) for key in range(last_key, then_key + 1) - ) + self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1)) self.entries[-1].queue.append(obj)