diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 8f5a526800..60f0de70f7 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -15,13 +15,12 @@
import logging
import re
-from itertools import islice
import attr
from twisted.internet import defer, task
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.logging import context
logger = logging.getLogger(__name__)
@@ -40,15 +39,16 @@ class Clock(object):
Args:
reactor: The Twisted reactor to use.
"""
+
_reactor = attr.ib()
@defer.inlineCallbacks
def sleep(self, seconds):
d = defer.Deferred()
- with PreserveLoggingContext():
+ with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
- defer.returnValue(res)
+ return res
def time(self):
"""Returns the current system time in seconds since epoch."""
@@ -61,7 +61,10 @@ class Clock(object):
def looping_call(self, f, msec, *args, **kwargs):
"""Call a function repeatedly.
- Waits `msec` initially before calling `f` for the first time.
+ Waits `msec` initially before calling `f` for the first time.
+
+ Note that the function will be called with no logcontext, so if it is anything
+ other than trivial, you probably want to wrap it in run_as_background_process.
Args:
f(function): The function to call repeatedly.
@@ -72,25 +75,27 @@ class Clock(object):
call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=False)
- d.addErrback(
- log_failure, "Looping call died", consumeErrors=False,
- )
+ d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call
def call_later(self, delay, callback, *args, **kwargs):
"""Call something later
+ Note that the function will be called with no logcontext, so if it is anything
+ other than trivial, you probably want to wrap it in run_as_background_process.
+
Args:
delay(float): How long to wait in seconds.
callback(function): Function to call
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
+
def wrapped_callback(*args, **kwargs):
- with PreserveLoggingContext():
+ with context.PreserveLoggingContext():
callback(*args, **kwargs)
- with PreserveLoggingContext():
+ with context.PreserveLoggingContext():
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer, ignore_errs=False):
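The "no logcontext" notes added above have a practical upshot: anything non-trivial scheduled through these helpers should be wrapped. A minimal sketch, assuming `run_as_background_process` accepts a coroutine function; `clock`, `_on_tick` and `_expire_sessions` are illustrative names:

```python
from synapse.metrics.background_process_metrics import run_as_background_process

async def _expire_sessions():
    ...  # illustrative: some non-trivial periodic work

def _on_tick():
    # run_as_background_process gives the work its own logcontext and
    # tracks it in the background-process metrics
    return run_as_background_process("expire_sessions", _expire_sessions)

clock.looping_call(_on_tick, 60 * 1000)  # tick every 60 seconds
```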
@@ -101,22 +106,6 @@ class Clock(object):
raise
-def batch_iter(iterable, size):
- """batch an iterable up into tuples with a maximum size
-
- Args:
- iterable (iterable): the iterable to slice
- size (int): the maximum batch size
-
- Returns:
- an iterator over the chunks
- """
- # make sure we can deal with iterables like lists too
- sourceiter = iter(iterable)
- # call islice until it returns an empty tuple
- return iter(lambda: tuple(islice(sourceiter, size)), ())
-
-
def log_failure(failure, msg, consumeErrors=True):
"""Creates a function suitable for passing to `Deferred.addErrback` that
logs any failures that occur.
@@ -131,12 +120,7 @@ def log_failure(failure, msg, consumeErrors=True):
"""
logger.error(
- msg,
- exc_info=(
- failure.type,
- failure.value,
- failure.getTracebackObject()
- )
+ msg, exc_info=(failure.type, failure.value, failure.getTracebackObject())
)
if not consumeErrors:
@@ -154,12 +138,12 @@ def glob_to_regex(glob):
Returns:
re.RegexObject
"""
- res = ''
+ res = ""
for c in glob:
- if c == '*':
- res = res + '.*'
- elif c == '?':
- res = res + '.'
+ if c == "*":
+ res = res + ".*"
+ elif c == "?":
+ res = res + "."
else:
res = res + re.escape(c)
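The `glob_to_regex` hunk only changes quote style; the translation itself is untouched. For reference, a standalone sketch of the character mapping it performs (the anchoring and flags applied when the pattern is compiled sit outside this hunk):

```python
import re

def _glob_body(glob):
    # same character-by-character translation as glob_to_regex above
    res = ""
    for c in glob:
        if c == "*":
            res = res + ".*"
        elif c == "?":
            res = res + "."
        else:
            res = res + re.escape(c)
    return res

assert _glob_body("foo*bar?") == "foo.*bar."
```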
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 7253ba120f..581dffd8a0 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,23 +13,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import collections
import logging
from contextlib import contextmanager
+from typing import Dict, Sequence, Set, Union
from six.moves import range
+import attr
+
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
-from synapse.util import Clock, logcontext, unwrapFirstError
-
-from .logcontext import (
+from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
+from synapse.util import Clock, unwrapFirstError
logger = logging.getLogger(__name__)
@@ -70,6 +73,10 @@ class ObservableDeferred(object):
def errback(f):
object.__setattr__(self, "_result", (False, f))
while self._observers:
+ # This is a little bit of magic to correctly propagate stack
+ # traces when we `await` on one of the observer deferreds.
+ f.value.__failure__ = f
+
try:
# TODO: Handle errors here.
self._observers.pop().errback(f)
@@ -83,11 +90,12 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback)
- def observe(self):
+ def observe(self) -> defer.Deferred:
"""Observe the underlying deferred.
- Can return either a deferred if the underlying deferred is still pending
- (or has failed), or the actual value. Callers may need to use maybeDeferred.
+ This returns a brand new deferred that is resolved when the underlying
+ deferred is resolved. Interacting with the returned deferred does not
+ affect the underlying deferred.
"""
if not self._result:
d = defer.Deferred()
@@ -95,13 +103,14 @@ class ObservableDeferred(object):
def remove(r):
self._observers.discard(d)
return r
+
d.addBoth(remove)
self._observers.add(d)
return d
else:
success, res = self._result
- return res if success else defer.fail(res)
+ return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
@@ -123,7 +132,9 @@ class ObservableDeferred(object):
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
- id(self), self._result, self._deferred,
+ id(self),
+ self._result,
+ self._deferred,
)
@@ -132,9 +143,9 @@ def concurrently_execute(func, args, limit):
the number of concurrent executions.
Args:
- func (func): Function to execute, should return a deferred.
- args (list): List of arguments to pass to func, each invocation of func
- gets a signle argument.
+ func (func): Function to execute, should return a deferred or coroutine.
+ args (Iterable): List of arguments to pass to func, each invocation of func
+ gets a single argument.
limit (int): Maximum number of concurrent executions.
Returns:
@@ -142,18 +153,19 @@ def concurrently_execute(func, args, limit):
"""
it = iter(args)
- @defer.inlineCallbacks
- def _concurrently_execute_inner():
+ async def _concurrently_execute_inner():
try:
while True:
- yield func(next(it))
+ await maybe_awaitable(func(next(it)))
except StopIteration:
pass
- return logcontext.make_deferred_yieldable(defer.gatherResults([
- run_in_background(_concurrently_execute_inner)
- for _ in range(limit)
- ], consumeErrors=True)).addErrback(unwrapFirstError)
+ return make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(_concurrently_execute_inner) for _ in range(limit)],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
def yieldable_gather_results(func, iter, *args, **kwargs):
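Since the inner loop now goes through `maybe_awaitable`, the supplied `func` may return either a `Deferred` or a coroutine. An illustrative caller (`fetch_user` and `warm_user_cache` are made-up names):

```python
from twisted.internet import defer

async def fetch_user(user_id):
    ...  # may equally be a function returning a Deferred

@defer.inlineCallbacks
def warm_user_cache(user_ids):
    # at most five invocations of fetch_user in flight at once
    yield concurrently_execute(fetch_user, user_ids, 5)
```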
@@ -169,10 +181,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
Deferred[list]: Resolved when all functions have been invoked, or errors if
one of the function calls fails.
"""
- return logcontext.make_deferred_yieldable(defer.gatherResults([
- run_in_background(func, item, *args, **kwargs)
- for item in iter
- ], consumeErrors=True)).addErrback(unwrapFirstError)
+ return make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(func, item, *args, **kwargs) for item in iter],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
class Linearizer(object):
@@ -185,6 +199,7 @@ class Linearizer(object):
# do some work.
"""
+
def __init__(self, name=None, max_count=1, clock=None):
"""
Args:
@@ -197,6 +212,7 @@ class Linearizer(object):
if not clock:
from twisted.internet import reactor
+
clock = Clock(reactor)
self._clock = clock
self.max_count = max_count
@@ -205,7 +221,9 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
- self.key_to_defer = {}
+ self.key_to_defer = (
+ {}
+ ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@@ -221,7 +239,7 @@ class Linearizer(object):
res = self._await_lock(key)
else:
logger.debug(
- "Acquired uncontended linearizer lock %r for key %r", self.name, key,
+ "Acquired uncontended linearizer lock %r for key %r", self.name, key
)
entry[0] += 1
res = defer.succeed(None)
@@ -266,9 +284,7 @@ class Linearizer(object):
"""
entry = self.key_to_defer[key]
- logger.debug(
- "Waiting to acquire linearizer lock %r for key %r", self.name, key,
- )
+ logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
new_defer = make_deferred_yieldable(defer.Deferred())
entry[1][new_defer] = 1
@@ -293,14 +309,14 @@ class Linearizer(object):
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.debug(
- "Cancelling wait for linearizer lock %r for key %r",
- self.name, key,
+ "Cancelling wait for linearizer lock %r for key %r", self.name, key
)
else:
- logger.warn(
+ logger.warning(
"Unexpected exception waiting for linearizer lock %r for key %r",
- self.name, key,
+ self.name,
+ key,
)
# we just have to take ourselves back out of the queue.
@@ -334,10 +350,10 @@ class ReadWriteLock(object):
def __init__(self):
# Latest readers queued
- self.key_to_current_readers = {}
+ self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued
- self.key_to_current_writer = {}
+ self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):
@@ -360,7 +376,7 @@ class ReadWriteLock(object):
new_defer.callback(None)
self.key_to_current_readers.get(key, set()).discard(new_defer)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
@defer.inlineCallbacks
def write(self, key):
@@ -390,7 +406,7 @@ class ReadWriteLock(object):
if self.key_to_current_writer[key] == new_defer:
self.key_to_current_writer.pop(key)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
def _cancelled_to_timed_out_error(value, timeout):
@@ -438,7 +454,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
try:
deferred.cancel()
- except: # noqa: E722, if we throw any exception it'll break time outs
+ except:  # noqa: E722, if we throw any exception it'll break timeouts
logger.exception("Canceller failed during timeout")
if not new_d.called:
@@ -473,3 +489,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb)
return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+ """Simple awaitable that returns the provided value.
+ """
+
+ value = attr.ib()
+
+ def __await__(self):
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+ """Convert a value to an awaitable if not already an awaitable.
+ """
+
+ if hasattr(value, "__await__"):
+ return value
+
+ return DoneAwaitable(value)
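`maybe_awaitable` is what lets the `await` in `_concurrently_execute_inner` cope with plain values as well as deferreds and coroutines. A sketch of the three shapes, assuming the coroutine is driven via `defer.ensureDeferred` on a running reactor:

```python
from twisted.internet import defer

async def _demo():
    assert await maybe_awaitable(42) == 42  # plain value -> DoneAwaitable
    assert await maybe_awaitable(defer.succeed(1)) == 1  # Deferreds have __await__

    async def _coro():
        return "x"

    assert await maybe_awaitable(_coro()) == "x"  # coroutines pass straight through
```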
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index f37d5bec08..da5077b471 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
import logging
import os
+from typing import Dict
import six
from six.moves import intern
@@ -36,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {}
-collectors_by_name = {}
+collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
@@ -51,7 +53,19 @@ response_cache_evicted = Gauge(
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
-def register_cache(cache_type, cache_name, cache):
+def register_cache(cache_type, cache_name, cache, collect_callback=None):
+ """Register a cache object for metric collection.
+
+ Args:
+ cache_type (str):
+ cache_name (str): name of the cache
+ cache (object): cache itself
+ collect_callback (callable|None): if not None, a function which is called during
+ metric collection to update additional metrics.
+
+ Returns:
+ CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+ """
# Check if the metric is already registered. Unregister it, if so.
# This usually happens during tests, as at runtime these caches are
@@ -90,8 +104,10 @@ def register_cache(cache_type, cache_name, cache):
cache_hits.labels(cache_name).set(self.hits)
cache_evicted.labels(cache_name).set(self.evicted_size)
cache_total.labels(cache_name).set(self.hits + self.misses)
+ if collect_callback:
+ collect_callback()
except Exception as e:
- logger.warn("Error calculating metrics for %s: %s", cache_name, e)
+ logger.warning("Error calculating metrics for %s: %s", cache_name, e)
raise
yield GaugeMetricFamily("__unused", "")
@@ -104,8 +120,8 @@ def register_cache(cache_type, cache_name, cache):
KNOWN_KEYS = {
- key: key for key in
- (
+ key: key
+ for key in (
"auth_events",
"content",
"depth",
@@ -150,7 +166,7 @@ def intern_dict(dictionary):
def _intern_known_values(key, value):
- intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
+ intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
if key in intern_keys:
return intern_string(value)
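The new `collect_callback` hook fires at metric-collection time; the `Cache` class in `descriptors.py` below uses it to keep a pending-lookups gauge current. A hedged sketch with made-up names (`my_lru_cache` is assumed to be an existing cache object):

```python
from prometheus_client import Gauge

my_pending_gauge = Gauge("my_cache_pending", "in-flight lookups", ["name"])
pending = {}  # illustrative stand-in for a dict of in-flight lookups

def _collect():
    # runs at scrape time; a good place to refresh extra gauges
    my_pending_gauge.labels("my_cache").set(len(pending))

metrics = register_cache("cache", "my_cache", my_lru_cache, collect_callback=_collect)
metrics.inc_hits()  # the returned CacheMetric behaves as before
```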
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 187510576a..2e8f6543e5 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,32 +17,53 @@ import functools
import inspect
import logging
import threading
-from collections import namedtuple
+from typing import Any, Tuple, Union, cast
+from weakref import WeakValueDictionary
-import six
-from six import itervalues, string_types
+from six import itervalues
+
+from prometheus_client import Gauge
+from typing_extensions import Protocol
from twisted.internet import defer
-from synapse.util import logcontext, unwrapFirstError
+from synapse.logging.context import make_deferred_yieldable, preserve_fn
+from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.stringutils import to_ascii
from . import register_cache
logger = logging.getLogger(__name__)
+CacheKey = Union[Tuple, Any]
+
+
+class _CachedFunction(Protocol):
+ invalidate = None # type: Any
+ invalidate_all = None # type: Any
+ invalidate_many = None # type: Any
+ prefill = None # type: Any
+ cache = None # type: Any
+ num_args = None # type: Any
+
+ def __name__(self):
+ ...
+
+
+cache_pending_metric = Gauge(
+ "synapse_util_caches_cache_pending",
+ "Number of lookups currently pending for this cache",
+ ["name"],
+)
_CacheSentinel = object()
class CacheEntry(object):
- __slots__ = [
- "deferred", "callbacks", "invalidated"
- ]
+ __slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(self, deferred, callbacks):
self.deferred = deferred
@@ -73,7 +94,9 @@ class Cache(object):
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
- max_size=max_entries, keylen=keylen, cache_type=cache_type,
+ max_size=max_entries,
+ keylen=keylen,
+ cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
)
@@ -81,11 +104,19 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.thread = None
- self.metrics = register_cache("cache", name, self.cache)
+ self.metrics = register_cache(
+ "cache",
+ name,
+ self.cache,
+ collect_callback=self._metrics_collection_callback,
+ )
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
+ def _metrics_collection_callback(self):
+ cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
+
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@@ -107,7 +138,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
- Either a Deferred or the raw result
+ Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -131,12 +162,14 @@ class Cache(object):
return default
def set(self, key, value, callback=None):
+ if not isinstance(value, defer.Deferred):
+ raise TypeError("not a Deferred")
+
callbacks = [callback] if callback else []
self.check_thread()
- entry = CacheEntry(
- deferred=value,
- callbacks=callbacks,
- )
+ observable = ObservableDeferred(value, consumeErrors=True)
+ observer = defer.maybeDeferred(observable.observe)
+ entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@@ -144,20 +177,31 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
- def shuffle(result):
+ def compare_and_pop():
+ """Check if our entry is still the one in _pending_deferred_cache, and
+ if so, pop it.
+
+ Returns true if the entries matched.
+ """
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
+ return True
+
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ return False
+
+ def cb(result):
+ if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
@@ -165,9 +209,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
- return result
- entry.deferred.addCallback(shuffle)
+ def eb(_fail):
+ compare_and_pop()
+ entry.invalidate()
+
+ # once the deferred completes, we can move the entry from the
+ # _pending_deferred_cache to the real cache.
+ #
+ observer.addCallbacks(cb, eb)
+ return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
@@ -191,9 +242,7 @@ class Cache(object):
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
- raise TypeError(
- "The cache key must be a tuple not %r" % (type(key),)
- )
+ raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
@@ -212,7 +261,9 @@ class Cache(object):
class _CacheDescriptorBase(object):
- def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+ def __init__(
+ self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
+ ):
self.orig = orig
if inlineCallbacks:
@@ -220,7 +271,7 @@ class _CacheDescriptorBase(object):
else:
self.function_to_call = orig
- arg_spec = inspect.getargspec(orig)
+ arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
if "cache_context" in all_args:
@@ -244,29 +295,25 @@ class _CacheDescriptorBase(object):
raise Exception(
"Not enough explicit positional arguments to key off for %r: "
"got %i args, but wanted %i. (@cached cannot key off *args or "
- "**kwargs)"
- % (orig.__name__, len(all_args), num_args)
+ "**kwargs)" % (orig.__name__, len(all_args), num_args)
)
self.num_args = num_args
# list of the names of the args used as the cache key
- self.arg_names = all_args[1:num_args + 1]
+ self.arg_names = all_args[1 : num_args + 1]
# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
- self.arg_defaults = dict(zip(
- all_args[-len(arg_spec.defaults):],
- arg_spec.defaults
- ))
+ self.arg_defaults = dict(
+ zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults)
+ )
else:
self.arg_defaults = {}
if "cache_context" in self.arg_names:
- raise Exception(
- "cache_context arg cannot be included among the cache keys"
- )
+ raise Exception("cache_context arg cannot be included among the cache keys")
self.add_cache_context = cache_context
@@ -297,19 +344,31 @@ class CacheDescriptor(_CacheDescriptorBase):
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
- defer.returnValue(r1 + r2)
+ return r1 + r2
Args:
num_args (int): number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
"""
- def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
- inlineCallbacks=False, cache_context=False, iterable=False):
+
+ def __init__(
+ self,
+ orig,
+ max_entries=1000,
+ num_args=None,
+ tree=False,
+ inlineCallbacks=False,
+ cache_context=False,
+ iterable=False,
+ ):
super(CacheDescriptor, self).__init__(
- orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
- cache_context=cache_context)
+ orig,
+ num_args=num_args,
+ inlineCallbacks=inlineCallbacks,
+ cache_context=cache_context,
+ )
max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
@@ -356,12 +415,14 @@ class CacheDescriptor(_CacheDescriptorBase):
return args[0]
else:
return self.arg_defaults[nm]
+
else:
+
def get_cache_key(args, kwargs):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
+ def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -371,7 +432,7 @@ class CacheDescriptor(_CacheDescriptorBase):
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
- kwargs["cache_context"] = _CacheContext(cache, cache_key)
+ kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
@@ -379,12 +440,11 @@ class CacheDescriptor(_CacheDescriptorBase):
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
- observer = cached_result_d
+ observer = defer.succeed(cached_result_d)
except KeyError:
ret = defer.maybeDeferred(
- logcontext.preserve_fn(self.function_to_call),
- obj, *args, **kwargs
+ preserve_fn(self.function_to_call), obj, *args, **kwargs
)
def onErr(f):
@@ -393,20 +453,12 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- # If our cache_key is a string on py2, try to convert to ascii
- # to save a bit of space in large caches. Py3 does this
- # internally automatically.
- if six.PY2 and isinstance(cache_key, string_types):
- cache_key = to_ascii(cache_key)
-
- result_d = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, result_d, callback=invalidate_callback)
+ result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
- if isinstance(observer, defer.Deferred):
- return logcontext.make_deferred_yieldable(observer)
- else:
- return observer
+ return make_deferred_yieldable(observer)
+
+ wrapped = cast(_CachedFunction, _wrapped)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -432,13 +484,13 @@ class CacheListDescriptor(_CacheDescriptorBase):
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped function.
- Once wrapped, the function returns either a Deferred which resolves to
- the list of results, or (if all results were cached), just the list of
- results.
+ Once wrapped, the function returns a Deferred which resolves to the list
+ of results.
"""
- def __init__(self, orig, cached_method_name, list_name, num_args=None,
- inlineCallbacks=False):
+ def __init__(
+ self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
+ ):
"""
Args:
orig (function)
@@ -451,7 +503,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
be wrapped by defer.inlineCallbacks
"""
super(CacheListDescriptor, self).__init__(
- orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
+ orig, num_args=num_args, inlineCallbacks=inlineCallbacks
+ )
self.list_name = list_name
@@ -463,7 +516,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
- % (self.list_name, cached_method_name,)
+ % (self.list_name, cached_method_name)
)
def __get__(self, obj, objtype=None):
@@ -494,8 +547,10 @@ class CacheListDescriptor(_CacheDescriptorBase):
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
+
def arg_to_cache_key(arg):
return arg
+
else:
keylist = list(keyargs)
@@ -505,8 +560,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
- res = cache.get(arg_to_cache_key(arg),
- callback=invalidate_callback)
+ res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
@@ -519,7 +573,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg)
if missing:
- # we need an observable deferred for each entry in the list,
+ # we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
@@ -527,8 +581,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
- observable = ObservableDeferred(deferred)
- cache.set(key, observable, callback=invalidate_callback)
+ cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a
@@ -554,40 +607,62 @@ class CacheListDescriptor(_CacheDescriptorBase):
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = list(missing)
- cached_defers.append(defer.maybeDeferred(
- logcontext.preserve_fn(self.function_to_call),
- **args_to_call
- ).addCallbacks(complete_all, errback))
+ cached_defers.append(
+ defer.maybeDeferred(
+ preserve_fn(self.function_to_call), **args_to_call
+ ).addCallbacks(complete_all, errback)
+ )
if cached_defers:
- d = defer.gatherResults(
- cached_defers,
- consumeErrors=True,
- ).addCallbacks(
- lambda _: results,
- unwrapFirstError
+ d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
+ lambda _: results, unwrapFirstError
)
- return logcontext.make_deferred_yieldable(d)
+ return make_deferred_yieldable(d)
else:
- return results
+ return defer.succeed(results)
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
-class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
- # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
- # which namedtuple does for us (i.e. two _CacheContext are the same if
- # their caches and keys match). This is important in particular to
- # dedupe when we add callbacks to lru cache nodes, otherwise the number
- # of callbacks would grow.
- def invalidate(self):
- self.cache.invalidate(self.key)
+class _CacheContext:
+ """Holds cache information from the cached function higher in the calling order.
+
+ Can be used to invalidate the higher level cache entry if something changes
+ on a lower level.
+ """
+
+ _cache_context_objects = (
+ WeakValueDictionary()
+ ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+
+ def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None
+ self._cache = cache
+ self._cache_key = cache_key
+
+ def invalidate(self): # type: () -> None
+ """Invalidates the cache entry referred to by the context."""
+ self._cache.invalidate(self._cache_key)
+
+ @classmethod
+ def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext
+ """Returns an instance constructed with the given arguments.
+
+ A new instance is only created if none already exists.
+ """
+
+ # We make sure there are no identical _CacheContext instances. This is
+ # important in particular to dedupe when we add callbacks to lru cache
+ # nodes, otherwise the number of callbacks would grow.
+ return cls._cache_context_objects.setdefault(
+ (cache, cache_key), cls(cache, cache_key)
+ )
-def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
- iterable=False):
+def cached(
+ max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
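Because `get_instance` interns instances in a `WeakValueDictionary` keyed by `(cache, cache_key)`, asking twice for the same pair yields the identical object, which keeps the callback lists on LRU cache nodes deduplicated. Sketch (`some_cache` is illustrative):

```python
ctx1 = _CacheContext.get_instance(some_cache, ("!room:example.com",))
ctx2 = _CacheContext.get_instance(some_cache, ("!room:example.com",))
assert ctx1 is ctx2  # same (cache, key) pair -> the same weakly-held instance
ctx1.invalidate()  # invalidates that entry in the underlying cache
```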
@@ -598,8 +673,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
- cache_context=False, iterable=False):
+def cachedInlineCallbacks(
+ max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
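Taken together, the descriptor changes mean a wrapped method now always returns a `Deferred`, and `Cache.set` insists on receiving one. Typical usage, sketched with illustrative names (mirroring the `cache_context` example in the `CacheDescriptor` docstring):

```python
class DataStore(object):
    @cached(max_entries=5000)
    def get_user(self, user_id):
        # must return a Deferred; the result is cached under user_id
        return self._fetch_user_from_db(user_id)

    @cachedInlineCallbacks(num_args=1, cache_context=True)
    def get_display_name(self, user_id, cache_context):
        user = yield self.get_user(user_id, on_invalidate=cache_context.invalidate)
        return user["display_name"]
```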
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 6c0b5a4094..6834e6f3ae 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
there.
value (dict): The full or partial dict value
"""
+
def __len__(self):
return len(self.value)
@@ -84,13 +85,15 @@ class DictionaryCache(object):
self.metrics.inc_hits()
if dict_keys is None:
- return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
+ return DictionaryEntry(
+ entry.full, entry.known_absent, dict(entry.value)
+ )
else:
- return DictionaryEntry(entry.full, entry.known_absent, {
- k: entry.value[k]
- for k in dict_keys
- if k in entry.value
- })
+ return DictionaryEntry(
+ entry.full,
+ entry.known_absent,
+ {k: entry.value[k] for k in dict_keys if k in entry.value},
+ )
self.metrics.inc_misses()
return DictionaryEntry(False, set(), {})
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index f369780277..cddf1ed515 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -28,8 +28,15 @@ SENTINEL = object()
class ExpiringCache(object):
- def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
- reset_expiry_on_get=False, iterable=False):
+ def __init__(
+ self,
+ cache_name,
+ clock,
+ max_len=0,
+ expiry_ms=0,
+ reset_expiry_on_get=False,
+ iterable=False,
+ ):
"""
Args:
cache_name (str): Name of this cache, used for logging.
@@ -67,8 +74,7 @@ class ExpiringCache(object):
def f():
return run_as_background_process(
- "prune_cache_%s" % self._cache_name,
- self._prune_cache,
+ "prune_cache_%s" % self._cache_name, self._prune_cache
)
self._clock.looping_call(f, self._expiry_ms / 2)
@@ -153,7 +159,9 @@ class ExpiringCache(object):
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
- self._cache_name, begin_length, len(self)
+ self._cache_name,
+ begin_length,
+ len(self),
)
def __len__(self):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index b684f24e7b..1536cb64f3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,8 +49,15 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
- def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
- evicted_callback=None):
+
+ def __init__(
+ self,
+ max_size,
+ keylen=1,
+ cache_type=dict,
+ size_callback=None,
+ evicted_callback=None,
+ ):
"""
Args:
max_size (int):
@@ -93,9 +100,12 @@ class LruCache(object):
cached_cache_len = [0]
if size_callback is not None:
+
def cache_len():
return cached_cache_len[0]
+
else:
+
def cache_len():
return len(cache)
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index afb03b2e1b..b68f9fe0d4 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -16,9 +16,9 @@ import logging
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -35,12 +35,10 @@ class ResponseCache(object):
self.pending_result_cache = {} # Requests that haven't finished yet.
self.clock = hs.get_clock()
- self.timeout_sec = timeout_ms / 1000.
+ self.timeout_sec = timeout_ms / 1000.0
self._name = name
- self._metrics = register_cache(
- "response_cache", name, self
- )
+ self._metrics = register_cache("response_cache", name, self)
def size(self):
return len(self.pending_result_cache)
@@ -80,7 +78,7 @@ class ResponseCache(object):
*deferred* should run its callbacks in the sentinel logcontext (ie,
you should wrap normal synapse deferreds with
- logcontext.run_in_background).
+ synapse.logging.context.run_in_background).
Can return either a new Deferred (which also doesn't follow the synapse
logcontext rules), or, if *deferred* was already complete, the actual
@@ -100,8 +98,7 @@ class ResponseCache(object):
def remove(r):
if self.timeout_sec:
self.clock.call_later(
- self.timeout_sec,
- self.pending_result_cache.pop, key, None,
+ self.timeout_sec, self.pending_result_cache.pop, key, None
)
else:
self.pending_result_cache.pop(key, None)
@@ -124,7 +121,7 @@ class ResponseCache(object):
@defer.inlineCallbacks
def handle_request(request):
# etc
- defer.returnValue(result)
+ return result
result = yield response_cache.wrap(
key,
@@ -140,21 +137,22 @@ class ResponseCache(object):
*args: positional parameters to pass to the callback, if it is used
- **kwargs: named paramters to pass to the callback, if it is used
+ **kwargs: named parameters to pass to the callback, if it is used
Returns:
twisted.internet.defer.Deferred: yieldable result
"""
result = self.get(key)
if not result:
- logger.info("[%s]: no cached result for [%s], calculating new one",
- self._name, key)
+ logger.debug(
+ "[%s]: no cached result for [%s], calculating new one", self._name, key
+ )
d = run_in_background(callback, *args, **kwargs)
result = self.set(key, d)
elif not isinstance(result, defer.Deferred) or result.called:
- logger.info("[%s]: using completed cached result for [%s]",
- self._name, key)
+ logger.info("[%s]: using completed cached result for [%s]", self._name, key)
else:
- logger.info("[%s]: using incomplete cached result for [%s]",
- self._name, key)
+ logger.info(
+ "[%s]: using incomplete cached result for [%s]", self._name, key
+ )
return make_deferred_yieldable(result)
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
deleted file mode 100644
index 8318db8d2c..0000000000
--- a/synapse/util/caches/snapshot_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.util.async_helpers import ObservableDeferred
-
-
-class SnapshotCache(object):
- """Cache for snapshots like the response of /initialSync.
- The response of initialSync only has to be a recent snapshot of the
- server state. It shouldn't matter to clients if it is a few minutes out
- of date.
-
- This caches a deferred response. Until the deferred completes it will be
- returned from the cache. This means that if the client retries the request
- while the response is still being computed, that original response will be
- used rather than trying to compute a new response.
-
- Once the deferred completes it will removed from the cache after 5 minutes.
- We delay removing it from the cache because a client retrying its request
- could race with us finishing computing the response.
-
- Rather than tracking precisely how long something has been in the cache we
- keep two generations of completed responses. Every 5 minutes discard the
- old generation, move the new generation to the old generation, and set the
- new generation to be empty. This means that a result will be in the cache
- somewhere between 5 and 10 minutes.
- """
-
- DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes.
-
- def __init__(self):
- self.pending_result_cache = {} # Request that haven't finished yet.
- self.prev_result_cache = {} # The older requests that have finished.
- self.next_result_cache = {} # The newer requests that have finished.
- self.time_last_rotated_ms = 0
-
- def rotate(self, time_now_ms):
- # Rotate once if the cache duration has passed since the last rotation.
- if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
- self.prev_result_cache = self.next_result_cache
- self.next_result_cache = {}
- self.time_last_rotated_ms += self.DURATION_MS
-
- # Rotate again if the cache duration has passed twice since the last
- # rotation.
- if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
- self.prev_result_cache = self.next_result_cache
- self.next_result_cache = {}
- self.time_last_rotated_ms = time_now_ms
-
- def get(self, time_now_ms, key):
- self.rotate(time_now_ms)
- # This cache is intended to deduplicate requests, so we expect it to be
- # missed most of the time. So we just lookup the key in all of the
- # dictionaries rather than trying to short circuit the lookup if the
- # key is found.
- result = self.prev_result_cache.get(key)
- result = self.next_result_cache.get(key, result)
- result = self.pending_result_cache.get(key, result)
- if result is not None:
- return result.observe()
- else:
- return None
-
- def set(self, time_now_ms, key, deferred):
- self.rotate(time_now_ms)
-
- result = ObservableDeferred(deferred)
-
- self.pending_result_cache[key] = result
-
- def shuffle_along(r):
- # When the deferred completes we shuffle it along to the first
- # generation of the result cache. So that it will eventually
- # expire from the rotation of that cache.
- self.next_result_cache[key] = result
- self.pending_result_cache.pop(key, None)
- return r
-
- result.addBoth(shuffle_along)
-
- return result.observe()
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 625aedc940..235f64049c 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -77,9 +77,8 @@ class StreamChangeCache(object):
if stream_pos >= self._earliest_known_stream_pos:
changed_entities = {
- self._cache[k] for k in self._cache.islice(
- start=self._cache.bisect_right(stream_pos),
- )
+ self._cache[k]
+ for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
}
result = changed_entities.intersection(entities)
@@ -114,8 +113,10 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
- return [self._cache[k] for k in self._cache.islice(
- start=self._cache.bisect_right(stream_pos))]
+ return [
+ self._cache[k]
+ for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
+ ]
else:
return None
@@ -136,7 +137,7 @@ class StreamChangeCache(object):
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(
- k, self._earliest_known_stream_pos,
+ k, self._earliest_known_stream_pos
)
self._entity_to_key.pop(r, None)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index dd4c9e6067..2ea4e4e911 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,3 +1,5 @@
+from typing import Dict
+
from six import itervalues
SENTINEL = object()
@@ -9,9 +11,10 @@ class TreeCache(object):
efficiently.
Keys must be tuples.
"""
+
def __init__(self):
self.size = 0
- self.root = {}
+ self.root = {} # type: Dict
def __setitem__(self, key, value):
return self.set(key, value)
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 5ba1862506..99646c7cf0 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -55,7 +55,7 @@ class TTLCache(object):
if e != SENTINEL:
self._expiry_list.remove(e)
- entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
+ entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
self._data[key] = entry
self._expiry_list.add(entry)
@@ -87,7 +87,8 @@ class TTLCache(object):
key: key to look up
Returns:
- Tuple[Any, float]: the value from the cache, and the expiry time
+ Tuple[Any, float, float]: the value from the cache, the expiry time
+ and the TTL
Raises:
KeyError if the entry is not found
@@ -99,7 +100,7 @@ class TTLCache(object):
self._metrics.inc_misses()
raise
self._metrics.inc_hits()
- return e.value, e.expiry_time
+ return e.value, e.expiry_time, e.ttl
def pop(self, key, default=SENTINEL):
"""Remove a value from the cache
@@ -155,7 +156,9 @@ class TTLCache(object):
@attr.s(frozen=True, slots=True)
class _CacheEntry(object):
"""TTLCache entry"""
+
# expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time = attr.ib()
+ ttl = attr.ib()
key = attr.ib()
value = attr.ib()
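With `ttl` stored on `_CacheEntry`, the lookup in this hunk (the `get_with_expiry` method, going by synapse's `TTLCache`) returns a 3-tuple, so callers can re-`set` an entry with its original lifetime. A sketch assuming the usual `TTLCache(cache_name, timer)` constructor:

```python
import time

cache = TTLCache("federation_keys", timer=time.time)
cache.set("server1", "key_v1", ttl=600)
value, expiry_time, ttl = cache.get_with_expiry("server1")
assert value == "key_v1" and ttl == 600  # the TTL is the new third element
```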
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index e14c8bdfda..45af8d3eeb 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -17,8 +17,8 @@ import logging
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -51,9 +51,7 @@ class Distributor(object):
if name in self.signals:
raise KeyError("%r already has a signal named %s" % (self, name))
- self.signals[name] = Signal(
- name,
- )
+ self.signals[name] = Signal(name)
if name in self.pre_registration:
signal = self.signals[name]
@@ -78,11 +76,7 @@ class Distributor(object):
if name not in self.signals:
raise KeyError("%r does not have a signal named %s" % (self, name))
- run_as_background_process(
- name,
- self.signals[name].fire,
- *args, **kwargs
- )
+ run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
class Signal(object):
@@ -118,22 +112,23 @@ class Signal(object):
def eb(failure):
logger.warning(
"%s signal observer %s failed: %r",
- self.name, observer, failure,
+ self.name,
+ observer,
+ failure,
exc_info=(
failure.type,
failure.value,
- failure.getTracebackObject()))
+ failure.getTracebackObject(),
+ ),
+ )
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
- deferreds = [
- run_in_background(do, o)
- for o in self.observers
- ]
+ deferreds = [run_in_background(do, o) for o in self.observers]
- return make_deferred_yieldable(defer.gatherResults(
- deferreds, consumeErrors=True,
- ))
+ return make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
def __repr__(self):
return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 629ed44149..8b17d1c8b8 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -17,7 +17,7 @@ from six.moves import queue
from twisted.internet import threads
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable, run_in_background
class BackgroundFileConsumer(object):
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 014edea971..f2ccd5e7c6 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -30,7 +30,7 @@ def freeze(o):
return o
try:
- return tuple([freeze(i) for i in o])
+ return tuple(freeze(i) for i in o)
except TypeError:
pass
@@ -60,11 +60,10 @@ def _handle_frozendict(obj):
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
return obj._dict
- raise TypeError('Object of type %s is not JSON serializable' %
- obj.__class__.__name__)
+ raise TypeError(
+ "Object of type %s is not JSON serializable" % obj.__class__.__name__
+ )
# A JSONEncoder which is capable of encoding frozendics without barfing
-frozendict_json_encoder = json.JSONEncoder(
- default=_handle_frozendict,
-)
+frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)
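As the comment notes, the encoder serialises a `frozendict` by fishing out its protected `_dict`. A sketch of round-tripping one through the module-level encoder:

```python
import json

from frozendict import frozendict

content = frozendict({"body": "hello"})
encoded = frozendict_json_encoder.encode({"content": content})
assert json.loads(encoded) == {"content": {"body": "hello"}}
```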
diff --git a/synapse/util/hash.py b/synapse/util/hash.py
new file mode 100644
index 0000000000..359168704e
--- /dev/null
+++ b/synapse/util/hash.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+
+import unpaddedbase64
+
+
+def sha256_and_url_safe_base64(input_text):
+ """SHA256 hash an input string, encode the digest as url-safe base64, and
+ return it.
+
+ :param input_text: string to hash
+ :type input_text: str
+
+ :returns a sha256 hashed and url-safe base64 encoded digest
+ :rtype: str
+ """
+ digest = hashlib.sha256(input_text.encode()).digest()
+ return unpaddedbase64.encode_base64(digest, urlsafe=True)
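A SHA-256 digest is 32 bytes, so the unpadded url-safe encoding is always 43 characters drawn from the url-safe alphabet. For example:

```python
digest = sha256_and_url_safe_base64("matrix.org")
assert len(digest) == 43  # 32-byte digest -> 43 chars of unpadded base64
assert all(c not in digest for c in "+/=")  # url-safe alphabet, no padding
```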
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 2d7ddc1cbe..3c0e8469f3 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
def create_resource_tree(desired_tree, root_resource):
- """Create the resource tree for this Home Server.
+ """Create the resource tree for this homeserver.
This is unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
@@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource):
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
- for path_seg in full_path.split(b'/')[1:-1]:
+ for path_seg in full_path.split(b"/")[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = NoResource()
@@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
# ===========================
# now attach the actual desired resource
- last_path_seg = full_path.split(b'/')[-1]
+ last_path_seg = full_path.split(b"/")[-1]
# if there is already a resource here, thieve its children and
# replace it
@@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource):
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
- child_res_id = _resource_id(
- existing_dummy_resource, child_name
- )
+ child_res_id = _resource_id(existing_dummy_resource, child_name)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
new file mode 100644
index 0000000000..06faeebe7f
--- /dev/null
+++ b/synapse/util/iterutils.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from itertools import islice
+from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+
+T = TypeVar("T")
+
+
+def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
+ """batch an iterable up into tuples with a maximum size
+
+ Args:
+ iterable (iterable): the iterable to slice
+ size (int): the maximum batch size
+
+ Returns:
+ an iterator over the chunks
+ """
+ # make sure we can deal with iterables like lists too
+ sourceiter = iter(iterable)
+ # call islice until it returns an empty tuple
+ return iter(lambda: tuple(islice(sourceiter, size)), ())
+
+
+ISeq = TypeVar("ISeq", bound=Sequence, covariant=True)
+
+
+def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
+ """Split the given sequence into chunks of the given size
+
+ The last chunk may be shorter than the given size.
+
+ If the input is empty, no chunks are returned.
+ """
+ return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
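Both helpers are deterministic and easy to sanity-check:

```python
assert list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]
assert list(chunk_seq("abcdefg", 3)) == ["abc", "def", "g"]
assert list(chunk_seq("", 3)) == []  # empty input -> no chunks
```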
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index d668e5a6b8..6dce03dd3a 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -70,7 +70,8 @@ class JsonEncodedObject(object):
dict
"""
d = {
- k: _encode(v) for (k, v) in self.__dict__.items()
+ k: _encode(v)
+ for (k, v) in self.__dict__.items()
if k in self.valid_keys and k not in self.internal_keys
}
d.update(self.unrecognized_keys)
@@ -78,7 +79,8 @@ class JsonEncodedObject(object):
def get_internal_dict(self):
d = {
- k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
+ k: _encode(v, internal=True)
+ for (k, v) in self.__dict__.items()
if k in self.valid_keys
}
d.update(self.unrecognized_keys)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index fe412355d8..40e5c10a49 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,633 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Thread-local-alike tracking of log contexts within synapse
-
-This module provides objects and utilities for tracking contexts through
-synapse code, so that log lines can include a request identifier, and so that
-CPU and database activity can be accounted for against the request that caused
-them.
-
-See doc/log_contexts.rst for details on how this works.
+"""
+Backwards compatibility re-exports of ``synapse.logging.context`` functionality.
"""
-import logging
-import threading
-
-from twisted.internet import defer, threads
-
-logger = logging.getLogger(__name__)
-
-try:
- import resource
-
- # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined
- # to be 1 on linux so we hard code it.
- RUSAGE_THREAD = 1
-
- # If the system doesn't support RUSAGE_THREAD then this should throw an
- # exception.
- resource.getrusage(RUSAGE_THREAD)
-
- def get_thread_resource_usage():
- return resource.getrusage(RUSAGE_THREAD)
-except Exception:
- # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
- # won't track resource usage by returning None.
- def get_thread_resource_usage():
- return None
-
-
-class ContextResourceUsage(object):
- """Object for tracking the resources used by a log context
-
- Attributes:
- ru_utime (float): user CPU time (in seconds)
- ru_stime (float): system CPU time (in seconds)
- db_txn_count (int): number of database transactions done
- db_sched_duration_sec (float): amount of time spent waiting for a
- database connection
- db_txn_duration_sec (float): amount of time spent doing database
- transactions (excluding scheduling time)
- evt_db_fetch_count (int): number of events requested from the database
- """
-
- __slots__ = [
- "ru_stime", "ru_utime",
- "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
- "evt_db_fetch_count",
- ]
-
- def __init__(self, copy_from=None):
- """Create a new ContextResourceUsage
-
- Args:
- copy_from (ContextResourceUsage|None): if not None, an object to
- copy stats from
- """
- if copy_from is None:
- self.reset()
- else:
- self.ru_utime = copy_from.ru_utime
- self.ru_stime = copy_from.ru_stime
- self.db_txn_count = copy_from.db_txn_count
-
- self.db_txn_duration_sec = copy_from.db_txn_duration_sec
- self.db_sched_duration_sec = copy_from.db_sched_duration_sec
- self.evt_db_fetch_count = copy_from.evt_db_fetch_count
-
- def copy(self):
- return ContextResourceUsage(copy_from=self)
-
- def reset(self):
- self.ru_stime = 0.
- self.ru_utime = 0.
- self.db_txn_count = 0
-
- self.db_txn_duration_sec = 0
- self.db_sched_duration_sec = 0
- self.evt_db_fetch_count = 0
-
- def __repr__(self):
- return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
- "db_txn_count='%r', db_txn_duration_sec='%r', "
- "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
- self.ru_stime,
- self.ru_utime,
- self.db_txn_count,
- self.db_txn_duration_sec,
- self.db_sched_duration_sec,
- self.evt_db_fetch_count,)
-
- def __iadd__(self, other):
- """Add another ContextResourceUsage's stats to this one's.
-
- Args:
- other (ContextResourceUsage): the other resource usage object
- """
- self.ru_utime += other.ru_utime
- self.ru_stime += other.ru_stime
- self.db_txn_count += other.db_txn_count
- self.db_txn_duration_sec += other.db_txn_duration_sec
- self.db_sched_duration_sec += other.db_sched_duration_sec
- self.evt_db_fetch_count += other.evt_db_fetch_count
- return self
-
- def __isub__(self, other):
- self.ru_utime -= other.ru_utime
- self.ru_stime -= other.ru_stime
- self.db_txn_count -= other.db_txn_count
- self.db_txn_duration_sec -= other.db_txn_duration_sec
- self.db_sched_duration_sec -= other.db_sched_duration_sec
- self.evt_db_fetch_count -= other.evt_db_fetch_count
- return self
-
- def __add__(self, other):
- res = ContextResourceUsage(copy_from=self)
- res += other
- return res
-
- def __sub__(self, other):
- res = ContextResourceUsage(copy_from=self)
- res -= other
- return res
-
-
-class LoggingContext(object):
- """Additional context for log formatting. Contexts are scoped within a
- "with" block.
-
- If a parent is given when creating a new context, then:
- - logging fields are copied from the parent to the new context on entry
- - when the new context exits, the cpu usage stats are copied from the
- child to the parent
-
- Args:
- name (str): Name for the context for debugging.
- parent_context (LoggingContext|None): The parent of the new context
- """
-
- __slots__ = [
- "previous_context", "name", "parent_context",
- "_resource_usage",
- "usage_start",
- "main_thread", "alive",
- "request", "tag",
- ]
-
- thread_local = threading.local()
-
- class Sentinel(object):
- """Sentinel to represent the root context"""
-
- __slots__ = []
-
- def __str__(self):
- return "sentinel"
-
- def copy_to(self, record):
- pass
-
- def start(self):
- pass
-
- def stop(self):
- pass
-
- def add_database_transaction(self, duration_sec):
- pass
-
- def add_database_scheduled(self, sched_sec):
- pass
-
- def record_event_fetch(self, event_count):
- pass
-
- def __nonzero__(self):
- return False
- __bool__ = __nonzero__ # python3
-
- sentinel = Sentinel()
-
- def __init__(self, name=None, parent_context=None, request=None):
- self.previous_context = LoggingContext.current_context()
- self.name = name
-
- # track the resources used by this context so far
- self._resource_usage = ContextResourceUsage()
-
- # If alive, this holds the thread resource usage from when the
- # logcontext last became active.
- self.usage_start = None
-
- self.main_thread = threading.current_thread()
- self.request = None
- self.tag = ""
- self.alive = True
-
- self.parent_context = parent_context
-
- if self.parent_context is not None:
- self.parent_context.copy_to(self)
-
- if request is not None:
- # the request param overrides the request from the parent context
- self.request = request
-
- def __str__(self):
- if self.request:
- return str(self.request)
- return "%s@%x" % (self.name, id(self))
-
- @classmethod
- def current_context(cls):
- """Get the current logging context from thread local storage
-
- Returns:
- LoggingContext: the current logging context
- """
- return getattr(cls.thread_local, "current_context", cls.sentinel)
-
- @classmethod
- def set_current_context(cls, context):
- """Set the current logging context in thread local storage
- Args:
- context(LoggingContext): The context to activate.
- Returns:
- The context that was previously active
- """
- current = cls.current_context()
-
- if current is not context:
- current.stop()
- cls.thread_local.current_context = context
- context.start()
- return current
-
- def __enter__(self):
- """Enters this logging context into thread local storage"""
- old_context = self.set_current_context(self)
- if self.previous_context != old_context:
- logger.warn(
- "Expected previous context %r, found %r",
- self.previous_context, old_context
- )
- self.alive = True
-
- return self
-
- def __exit__(self, type, value, traceback):
- """Restore the logging context in thread local storage to the state it
- was before this context was entered.
- Returns:
- None to avoid suppressing any exceptions that were thrown.
- """
- current = self.set_current_context(self.previous_context)
- if current is not self:
- if current is self.sentinel:
- logger.warning("Expected logging context %s was lost", self)
- else:
- logger.warning(
- "Expected logging context %s but found %s", self, current
- )
- self.previous_context = None
- self.alive = False
-
- # if we have a parent, pass our CPU usage stats on
- if (
- self.parent_context is not None
- and hasattr(self.parent_context, '_resource_usage')
- ):
- self.parent_context._resource_usage += self._resource_usage
-
- # reset them in case we get entered again
- self._resource_usage.reset()
-
- def copy_to(self, record):
- """Copy logging fields from this context to a log record or
- another LoggingContext
- """
-
- # 'request' is the only field we currently use in the logger, so that's
- # all we need to copy
- record.request = self.request
-
- def start(self):
- if threading.current_thread() is not self.main_thread:
- logger.warning("Started logcontext %s on different thread", self)
- return
-
- # If we haven't already started record the thread resource usage so
- # far
- if not self.usage_start:
- self.usage_start = get_thread_resource_usage()
-
- def stop(self):
- if threading.current_thread() is not self.main_thread:
- logger.warning("Stopped logcontext %s on different thread", self)
- return
-
- # When we stop, let's record the cpu used since we started
- if not self.usage_start:
- logger.warning(
- "Called stop on logcontext %s without calling start", self,
- )
- return
-
- usage_end = get_thread_resource_usage()
-
- self._resource_usage.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime
- self._resource_usage.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime
-
- self.usage_start = None
-
- def get_resource_usage(self):
- """Get resources used by this logcontext so far.
-
- Returns:
- ContextResourceUsage: a *copy* of the object tracking resource
- usage so far
- """
- # we always return a copy, for consistency
- res = self._resource_usage.copy()
-
- # If we are on the correct thread and we're currently running then we
- # can include resource usage so far.
- is_main_thread = threading.current_thread() is self.main_thread
- if self.alive and self.usage_start and is_main_thread:
- current = get_thread_resource_usage()
- res.ru_utime += current.ru_utime - self.usage_start.ru_utime
- res.ru_stime += current.ru_stime - self.usage_start.ru_stime
-
- return res
-
- def add_database_transaction(self, duration_sec):
- self._resource_usage.db_txn_count += 1
- self._resource_usage.db_txn_duration_sec += duration_sec
-
- def add_database_scheduled(self, sched_sec):
- """Record a use of the database pool
-
- Args:
- sched_sec (float): number of seconds it took us to get a
- connection
- """
- self._resource_usage.db_sched_duration_sec += sched_sec
-
- def record_event_fetch(self, event_count):
- """Record a number of events being fetched from the db
-
- Args:
- event_count (int): number of events being fetched
- """
- self._resource_usage.evt_db_fetch_count += event_count
-
-
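(Aside: the class above is always used as a context manager, and nesting is what drives the parent/child usage accounting. A minimal usage sketch against the API shown here:)

from synapse.logging.context import LoggingContext

with LoggingContext(name="request", request="GET-123") as parent:
    # log lines emitted here carry request=GET-123 via LoggingContextFilter
    with LoggingContext(name="subtask", parent_context=parent):
        pass  # CPU/db usage accrued here is folded into `parent` on exit

    usage = parent.get_resource_usage()  # always a copy, safe to keep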
-class LoggingContextFilter(logging.Filter):
- """Logging filter that adds values from the current logging context to each
- record.
- Args:
- **defaults: Default values to avoid formatters complaining about
- missing fields
- """
- def __init__(self, **defaults):
- self.defaults = defaults
-
- def filter(self, record):
- """Add each fields from the logging contexts to the record.
- Returns:
- True to include the record in the log output.
- """
- context = LoggingContext.current_context()
- for key, value in self.defaults.items():
- setattr(record, key, value)
-
- # context should never be None, but if it somehow ends up being, then
- # we end up in a death spiral of infinite loops, so let's check, for
- # robustness' sake.
- if context is not None:
- context.copy_to(record)
-
- return True
-
-
-class PreserveLoggingContext(object):
- """Captures the current logging context and restores it when the scope is
- exited. Used to restore the context after a function using
- @defer.inlineCallbacks is resumed by a callback from the reactor."""
-
- __slots__ = ["current_context", "new_context", "has_parent"]
-
- def __init__(self, new_context=None):
- if new_context is None:
- new_context = LoggingContext.sentinel
- self.new_context = new_context
-
- def __enter__(self):
- """Captures the current logging context"""
- self.current_context = LoggingContext.set_current_context(
- self.new_context
- )
-
- if self.current_context:
- self.has_parent = self.current_context.previous_context is not None
- if not self.current_context.alive:
- logger.debug(
- "Entering dead context: %s",
- self.current_context,
- )
-
- def __exit__(self, type, value, traceback):
- """Restores the current logging context"""
- context = LoggingContext.set_current_context(self.current_context)
-
- if context != self.new_context:
- if context is LoggingContext.sentinel:
- logger.warning("Expected logging context %s was lost", self.new_context)
- else:
- logger.warning(
- "Expected logging context %s but found %s",
- self.new_context,
- context,
- )
-
- if self.current_context is not LoggingContext.sentinel:
- if not self.current_context.alive:
- logger.debug(
- "Restoring dead context: %s",
- self.current_context,
- )
-
-
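(Aside: PreserveLoggingContext is essentially a scoped set_current_context. A hedged sketch of the common pattern of re-entering a previously captured context for the duration of a callback:)

from synapse.logging.context import PreserveLoggingContext

def run_in(saved_context, callback, *args, **kwargs):
    # Re-enter a previously-captured logcontext for the duration of the
    # callback, restoring whatever context was current afterwards.
    with PreserveLoggingContext(saved_context):
        return callback(*args, **kwargs)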
-def nested_logging_context(suffix, parent_context=None):
- """Creates a new logging context as a child of another.
-
- The nested logging context will have a 'request' made up of the parent context's
- request, plus the given suffix.
-
- CPU/db usage stats will be added to the parent context's on exit.
-
- Normal usage looks like:
-
- with nested_logging_context(suffix):
- # ... do stuff
-
- Args:
- suffix (str): suffix to add to the parent context's 'request'.
- parent_context (LoggingContext|None): parent context. Will use the current context
- if None.
-
- Returns:
- LoggingContext: new logging context.
- """
- if parent_context is None:
- parent_context = LoggingContext.current_context()
- return LoggingContext(
- parent_context=parent_context,
- request=parent_context.request + "-" + suffix,
- )
-
-
-def preserve_fn(f):
- """Function decorator which wraps the function with run_in_background"""
- def g(*args, **kwargs):
- return run_in_background(f, *args, **kwargs)
- return g
-
-
-def run_in_background(f, *args, **kwargs):
- """Calls a function, ensuring that the current context is restored after
- return from the function, and that the sentinel context is set once the
- deferred returned by the function completes.
-
- Useful for wrapping functions that return a deferred which you don't yield
- on (for instance because you want to pass it to deferred.gatherResults()).
-
- Note that if you completely discard the result, you should make sure that
- `f` doesn't raise any deferred exceptions, otherwise a scary-looking
- CRITICAL error about an unhandled error will be logged without much
- indication about where it came from.
- """
- current = LoggingContext.current_context()
- try:
- res = f(*args, **kwargs)
- except: # noqa: E722
- # the assumption here is that the caller doesn't want to be disturbed
- # by synchronous exceptions, so let's turn them into Failures.
- return defer.fail()
-
- if not isinstance(res, defer.Deferred):
- return res
-
- if res.called and not res.paused:
- # The function should have maintained the logcontext, so we can
- # optimise out the messing about
- return res
-
- # The function may have reset the context before returning, so
- # we need to restore it now.
- ctx = LoggingContext.set_current_context(current)
-
- # The original context will be restored when the deferred
- # completes, but there is nothing waiting for it, so it will
- # get leaked into the reactor or some other function which
- # wasn't expecting it. We therefore need to reset the context
- # here.
- #
- # (If this feels asymmetric, consider it this way: we are
- # effectively forking a new thread of execution. We are
- # probably currently within a ``with LoggingContext()`` block,
- # which is supposed to have a single entry and exit point. But
- # by spawning off another deferred, we are effectively
- # adding a new exit point.)
- res.addBoth(_set_context_cb, ctx)
- return res
-
-
-def make_deferred_yieldable(deferred):
- """Given a deferred, make it follow the Synapse logcontext rules:
-
- If the deferred has completed (or is not actually a Deferred), essentially
- does nothing (just returns another completed deferred with the
- result/failure).
-
- If the deferred has not yet completed, resets the logcontext before
- returning a deferred. Then, when the deferred completes, restores the
- current logcontext before running callbacks/errbacks.
-
- (This is more-or-less the opposite operation to run_in_background.)
- """
- if not isinstance(deferred, defer.Deferred):
- return deferred
-
- if deferred.called and not deferred.paused:
- # it looks like this deferred is ready to run any callbacks we give it
- # immediately. We may as well optimise out the logcontext faffery.
- return deferred
-
- # ok, we can't be sure that a yield won't block, so let's reset the
- # logcontext, and add a callback to the deferred to restore it.
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
- deferred.addBoth(_set_context_cb, prev_context)
- return deferred
-
-
-def _set_context_cb(result, context):
- """A callback function which just sets the logging context"""
- LoggingContext.set_current_context(context)
- return result
-
-
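(Aside: to make the symmetry between run_in_background and make_deferred_yieldable concrete, here is an illustrative sketch; fetch_thing is a made-up stand-in for any deferred-returning function:)

from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background

def fetch_thing(key):
    # stand-in for some function that returns a deferred
    return defer.succeed("value-for-%s" % (key,))

@defer.inlineCallbacks
def handler():
    # fork: each call runs under our logcontext, which is restored to us
    # straight away; the deferreds are "context-free" until awaited
    d1 = run_in_background(fetch_thing, "a")
    d2 = run_in_background(fetch_thing, "b")

    # join: drop to the sentinel while blocked, restore our context after
    results = yield make_deferred_yieldable(
        defer.gatherResults([d1, d2], consumeErrors=True)
    )
    return results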
-def defer_to_thread(reactor, f, *args, **kwargs):
- """
- Calls the function `f` using a thread from the reactor's default threadpool and
- returns the result as a Deferred.
-
- Creates a new logcontext for `f`, which is created as a child of the current
- logcontext (so its CPU usage metrics will get attributed to the current
- logcontext). `f` should preserve the logcontext it is given.
-
- The result deferred follows the Synapse logcontext rules: you should `yield`
- on it.
-
- Args:
- reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
- the Deferred will be invoked, and whose threadpool we should use for the
- function.
-
- Normally this will be hs.get_reactor().
-
- f (callable): The function to call.
-
- args: positional arguments to pass to f.
-
- kwargs: keyword arguments to pass to f.
-
- Returns:
- Deferred: A Deferred which fires a callback with the result of `f`, or an
- errback if `f` throws an exception.
- """
- return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
-
-
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
- """
- A wrapper for twisted.internet.threads.deferToThreadpool, which handles
- logcontexts correctly.
-
- Calls the function `f` using a thread from the given threadpool and returns
- the result as a Deferred.
-
- Creates a new logcontext for `f`, which is created as a child of the current
- logcontext (so its CPU usage metrics will get attributed to the current
- logcontext). `f` should preserve the logcontext it is given.
-
- The result deferred follows the Synapse logcontext rules: you should `yield`
- on it.
-
- Args:
- reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
- the Deferred will be invoked. Normally this will be hs.get_reactor().
-
- threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for
- running `f`. Normally this will be hs.get_reactor().getThreadPool().
-
- f (callable): The function to call.
-
- args: positional arguments to pass to f.
-
- kwargs: keyword arguments to pass to f.
-
- Returns:
- Deferred: A Deferred which fires a callback with the result of `f`, or an
- errback if `f` throws an exception.
- """
- logcontext = LoggingContext.current_context()
-
- def g():
- with LoggingContext(parent_context=logcontext):
- return f(*args, **kwargs)
-
- return make_deferred_yieldable(
- threads.deferToThreadPool(reactor, threadpool, g)
- )
+from synapse.logging.context import (
+ LoggingContext,
+ LoggingContextFilter,
+ PreserveLoggingContext,
+ defer_to_thread,
+ make_deferred_yieldable,
+ nested_logging_context,
+ preserve_fn,
+ run_in_background,
+)
+
+__all__ = [
+ "defer_to_thread",
+ "LoggingContext",
+ "LoggingContextFilter",
+ "make_deferred_yieldable",
+ "nested_logging_context",
+ "preserve_fn",
+ "PreserveLoggingContext",
+ "run_in_background",
+]
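(Aside: the net effect of this change is a pure compatibility shim. Assuming this hunk is synapse/util/logcontext.py, mirroring the logformatter change below, old import sites keep working:)

# both names now resolve to the same object
from synapse.util.logcontext import LoggingContext as old_location
from synapse.logging.context import LoggingContext as new_location

assert old_location is new_location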
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index a46bc47ce3..320e8f8174 100644
--- a/synapse/util/logformatter.py
+++ b/synapse/util/logformatter.py
@@ -1,5 +1,4 @@
-# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,40 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Backwards compatibility re-exports of ``synapse.logging.formatter`` functionality.
+"""
-import logging
-import traceback
+from synapse.logging.formatter import LogFormatter
-from six import StringIO
-
-
-class LogFormatter(logging.Formatter):
- """Log formatter which gives more detail for exceptions
-
- This is the same as the standard log formatter, except that when logging
- exceptions [typically via log.foo("msg", exc_info=1)], it prints the
- sequence that led up to the point at which the exception was caught.
- (Normally only stack frames between the point the exception was raised and
- where it was caught are logged).
- """
- def __init__(self, *args, **kwargs):
- super(LogFormatter, self).__init__(*args, **kwargs)
-
- def formatException(self, ei):
- sio = StringIO()
- (typ, val, tb) = ei
-
- # log the stack above the exception capture point if possible, but
- # check that we actually have an f_back attribute to work around
- # https://twistedmatrix.com/trac/ticket/9305
-
- if tb and hasattr(tb.tb_frame, 'f_back'):
- sio.write("Capture point (most recent call last):\n")
- traceback.print_stack(tb.tb_frame.f_back, None, sio)
-
- traceback.print_exception(typ, val, tb, None, sio)
- s = sio.getvalue()
- sio.close()
- if s[-1:] == "\n":
- s = s[:-1]
- return s
+__all__ = ["LogFormatter"]
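(Aside: a sketch of wiring the re-exported formatter into stdlib logging; the format string is arbitrary:)

import logging

from synapse.logging.formatter import LogFormatter

handler = logging.StreamHandler()
handler.setFormatter(LogFormatter("%(asctime)s - %(levelname)s - %(message)s"))
logging.getLogger().addHandler(handler)

try:
    1 / 0
except ZeroDivisionError:
    # exc_info=True routes through formatException(), which prepends the
    # "Capture point" stack described in the docstring being removed above
    logging.getLogger().error("division failed", exc_info=True)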
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
deleted file mode 100644
index ef31458226..0000000000
--- a/synapse/util/logutils.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import inspect
-import logging
-import time
-from functools import wraps
-from inspect import getcallargs
-
-from six import PY3
-
-_TIME_FUNC_ID = 0
-
-
-def _log_debug_as_f(f, msg, msg_args):
- name = f.__module__
- logger = logging.getLogger(name)
-
- if logger.isEnabledFor(logging.DEBUG):
- if PY3:
- lineno = f.__code__.co_firstlineno
- pathname = f.__code__.co_filename
- else:
- lineno = f.func_code.co_firstlineno
- pathname = f.func_code.co_filename
-
- record = logging.LogRecord(
- name=name,
- level=logging.DEBUG,
- pathname=pathname,
- lineno=lineno,
- msg=msg,
- args=msg_args,
- exc_info=None
- )
-
- logger.handle(record)
-
-
-def log_function(f):
- """ Function decorator that logs every call to that function.
- """
- func_name = f.__name__
-
- @wraps(f)
- def wrapped(*args, **kwargs):
- name = f.__module__
- logger = logging.getLogger(name)
- level = logging.DEBUG
-
- if logger.isEnabledFor(level):
- bound_args = getcallargs(f, *args, **kwargs)
-
- def format(value):
- r = str(value)
- if len(r) > 50:
- r = r[:50] + "..."
- return r
-
- func_args = [
- "%s=%s" % (k, format(v)) for k, v in bound_args.items()
- ]
-
- msg_args = {
- "func_name": func_name,
- "args": ", ".join(func_args)
- }
-
- _log_debug_as_f(
- f,
- "Invoked '%(func_name)s' with args: %(args)s",
- msg_args
- )
-
- return f(*args, **kwargs)
-
- wrapped.__name__ = func_name
- return wrapped
-
-
-def time_function(f):
- func_name = f.__name__
-
- @wraps(f)
- def wrapped(*args, **kwargs):
- global _TIME_FUNC_ID
- id = _TIME_FUNC_ID
- _TIME_FUNC_ID += 1
-
- start = time.clock()
-
- try:
- _log_debug_as_f(
- f,
- "[FUNC START] {%s-%d}",
- (func_name, id),
- )
-
- r = f(*args, **kwargs)
- finally:
- end = time.clock()
- _log_debug_as_f(
- f,
- "[FUNC END] {%s-%d} %.3f sec",
- (func_name, id, end - start,),
- )
-
- return r
-
- return wrapped
-
-
-def trace_function(f):
- func_name = f.__name__
- linenum = f.func_code.co_firstlineno
- pathname = f.func_code.co_filename
-
- @wraps(f)
- def wrapped(*args, **kwargs):
- name = f.__module__
- logger = logging.getLogger(name)
- level = logging.DEBUG
-
- s = inspect.currentframe().f_back
-
- to_print = [
- "\t%s:%s %s. Args: args=%s, kwargs=%s" % (
- pathname, linenum, func_name, args, kwargs
- )
- ]
- while s:
- if True or s.f_globals["__name__"].startswith("synapse"):
- filename, lineno, function, _, _ = inspect.getframeinfo(s)
- args_string = inspect.formatargvalues(*inspect.getargvalues(s))
-
- to_print.append(
- "\t%s:%d %s. Args: %s" % (
- filename, lineno, function, args_string
- )
- )
-
- s = s.f_back
-
- msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print)
-
- record = logging.LogRecord(
- name=name,
- level=level,
- pathname=pathname,
- lineno=lineno,
- msg=msg,
- args=None,
- exc_info=None
- )
-
- logger.handle(record)
-
- return f(*args, **kwargs)
-
- wrapped.__name__ = func_name
- return wrapped
-
-
-def get_previous_frames():
- s = inspect.currentframe().f_back.f_back
- to_return = []
- while s:
- if s.f_globals["__name__"].startswith("synapse"):
- filename, lineno, function, _, _ = inspect.getframeinfo(s)
- args_string = inspect.formatargvalues(*inspect.getargvalues(s))
-
- to_return.append("{{ %s:%d %s - Args: %s }}" % (
- filename, lineno, function, args_string
- ))
-
- s = s.f_back
-
- return ", ". join(to_return)
-
-
-def get_previous_frame(ignore=[]):
- s = inspect.currentframe().f_back.f_back
-
- while s:
- if s.f_globals["__name__"].startswith("synapse"):
- if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
- filename, lineno, function, _, _ = inspect.getframeinfo(s)
- args_string = inspect.formatargvalues(*inspect.getargvalues(s))
-
- return "{{ %s:%d %s - Args: %s }}" % (
- filename, lineno, function, args_string
- )
-
- s = s.f_back
-
- return None
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 628a2962d9..631654f297 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -74,27 +74,25 @@ def manhole(username, password, globals):
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
"""
if not isinstance(password, bytes):
- password = password.encode('ascii')
+ password = password.encode("ascii")
- checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
- **{username: password}
- )
+ checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
rlm = manhole_ssh.TerminalRealm()
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
- SynapseManhole,
- dict(globals, __name__="__console__")
+ SynapseManhole, dict(globals, __name__="__console__")
)
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
- factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY)
- factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY)
+ factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY)
+ factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY)
return factory
class SynapseManhole(ColoredManhole):
"""Overrides connectionMade to create our own ManholeInterpreter"""
+
def connectionMade(self):
super(SynapseManhole, self).connectionMade()
@@ -127,7 +125,7 @@ class SynapseManholeInterpreter(ManholeInterpreter):
value = SyntaxError(msg, (filename, lineno, offset, line))
sys.last_value = value
lines = traceback.format_exception_only(type, value)
- self.write(''.join(lines))
+ self.write("".join(lines))
def showtraceback(self):
"""Display the exception that just occurred.
@@ -140,6 +138,6 @@ class SynapseManholeInterpreter(ManholeInterpreter):
try:
# We remove the first stack item because it is our own code.
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
- self.write(''.join(lines))
+ self.write("".join(lines))
finally:
last_tb = ei = None
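(Aside: the factory returned by manhole() is used like any other Twisted protocol factory, as its docstring notes; the port and credentials below are invented:)

from twisted.internet import reactor

from synapse.util.manhole import manhole

# expose an SSH REPL on localhost:9000 (illustrative values only)
factory = manhole(username="matrix", password="rabbithole", globals={})
reactor.listenTCP(9000, factory, interface="127.0.0.1")
reactor.run()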
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4b4ac5f6c7..7b18455469 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from functools import wraps
@@ -20,8 +21,8 @@ from prometheus_client import Counter
from twisted.internet import defer
+from synapse.logging.context import LoggingContext
from synapse.metrics import InFlightGauge
-from synapse.util.logcontext import LoggingContext
logger = logging.getLogger(__name__)
@@ -30,108 +31,108 @@ block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"])
block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"])
block_ru_utime = Counter(
- "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]
+)
block_ru_stime = Counter(
- "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]
+)
block_db_txn_count = Counter(
- "synapse_util_metrics_block_db_txn_count", "", ["block_name"])
+ "synapse_util_metrics_block_db_txn_count", "", ["block_name"]
+)
# seconds spent waiting for db txns, excluding scheduling time, in this block
block_db_txn_duration = Counter(
- "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]
+)
# seconds spent waiting for a db connection, in this block
block_db_sched_duration = Counter(
- "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
+)
# Tracks the number of blocks currently active
in_flight = InFlightGauge(
- "synapse_util_metrics_block_in_flight", "",
+ "synapse_util_metrics_block_in_flight",
+ "",
labels=["block_name"],
sub_metrics=["real_time_max", "real_time_sum"],
)
-def measure_func(name):
+def measure_func(name=None):
def wrapper(func):
- @wraps(func)
- @defer.inlineCallbacks
- def measured_func(self, *args, **kwargs):
- with Measure(self.clock, name):
- r = yield func(self, *args, **kwargs)
- defer.returnValue(r)
+ block_name = func.__name__ if name is None else name
+
+ if inspect.iscoroutinefunction(func):
+
+ @wraps(func)
+ async def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, block_name):
+ r = await func(self, *args, **kwargs)
+ return r
+
+ else:
+
+ @wraps(func)
+ @defer.inlineCallbacks
+ def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, block_name):
+ r = yield func(self, *args, **kwargs)
+ return r
+
return measured_func
+
return wrapper
class Measure(object):
__slots__ = [
- "clock", "name", "start_context", "start",
- "created_context",
- "start_usage",
+ "clock",
+ "name",
+ "_logging_context",
+ "start",
]
def __init__(self, clock, name):
self.clock = clock
self.name = name
- self.start_context = None
+ self._logging_context = None
self.start = None
- self.created_context = False
def __enter__(self):
- self.start = self.clock.time()
- self.start_context = LoggingContext.current_context()
- if not self.start_context:
- self.start_context = LoggingContext("Measure")
- self.start_context.__enter__()
- self.created_context = True
-
- self.start_usage = self.start_context.get_resource_usage()
+ if self._logging_context:
+ raise RuntimeError("Measure() objects cannot be re-used")
+ self.start = self.clock.time()
+ parent_context = LoggingContext.current_context()
+ self._logging_context = LoggingContext(
+ "Measure[%s]" % (self.name,), parent_context
+ )
+ self._logging_context.__enter__()
in_flight.register((self.name,), self._update_in_flight)
def __exit__(self, exc_type, exc_val, exc_tb):
- if isinstance(exc_type, Exception) or not self.start_context:
- return
-
- in_flight.unregister((self.name,), self._update_in_flight)
+ if not self._logging_context:
+ raise RuntimeError("Measure() block exited without being entered")
duration = self.clock.time() - self.start
+ usage = self._logging_context.get_resource_usage()
- block_counter.labels(self.name).inc()
- block_timer.labels(self.name).inc(duration)
-
- context = LoggingContext.current_context()
-
- if context != self.start_context:
- logger.warn(
- "Context has unexpectedly changed from '%s' to '%s'. (%r)",
- self.start_context, context, self.name
- )
- return
-
- if not context:
- logger.warn("Expected context. (%r)", self.name)
- return
+ in_flight.unregister((self.name,), self._update_in_flight)
+ self._logging_context.__exit__(exc_type, exc_val, exc_tb)
- current = context.get_resource_usage()
- usage = current - self.start_usage
try:
+ block_counter.labels(self.name).inc()
+ block_timer.labels(self.name).inc(duration)
block_ru_utime.labels(self.name).inc(usage.ru_utime)
block_ru_stime.labels(self.name).inc(usage.ru_stime)
block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
except ValueError:
- logger.warn(
- "Failed to save metrics! OLD: %r, NEW: %r",
- self.start_usage, current
- )
-
- if self.created_context:
- self.start_context.__exit__(exc_type, exc_val, exc_tb)
+ logger.warning("Failed to save metrics! Usage: %s", usage)
def _update_in_flight(self, metrics):
"""Gets called when processing in flight metrics
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 4288312b8a..bb62db4637 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -14,12 +14,13 @@
# limitations under the License.
import importlib
+import importlib.util
from synapse.config._base import ConfigError
def load_module(provider):
- """ Loads a module with its config
+ """ Loads a synapse module with its config
Take a dict with keys 'module' (the module name) and 'config'
(the config dict).
@@ -28,15 +29,30 @@ def load_module(provider):
"""
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
- module, clz = provider['module'].rsplit(".", 1)
+ module, clz = provider["module"].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
try:
- provider_config = provider_class.parse_config(provider["config"])
+ provider_config = provider_class.parse_config(provider.get("config"))
except Exception as e:
- raise ConfigError(
- "Failed to parse config for %r: %r" % (provider['module'], e)
- )
+ raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
return provider_class, provider_config
+
+
+def load_python_module(location: str):
+ """Load a python module, and return a reference to its global namespace
+
+ Args:
+ location (str): path to the module
+
+ Returns:
+ python module object
+ """
+ spec = importlib.util.spec_from_file_location(location, location)
+ if spec is None:
+ raise Exception("Unable to load module at %s" % (location,))
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod) # type: ignore
+ return mod
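(Aside: the two loaders serve different purposes; a sketch with made-up paths and module names:)

from synapse.util.module_loader import load_module, load_python_module

# dotted-path style, as used for password providers and the like; the
# "config" key may now be omitted, since parse_config gets provider.get("config")
provider_class, provider_config = load_module(
    {"module": "my_module.MyAuthProvider", "config": {"endpoint": "https://example.com"}}
)

# file-path style: returns the module object, e.g. for a log_config.py
log_config = load_python_module("/etc/synapse/log_config.py")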
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
index a6c30e5265..c8bcbe297a 100644
--- a/synapse/util/msisdn.py
+++ b/synapse/util/msisdn.py
@@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number):
phoneNumber = phonenumbers.parse(number, country)
except phonenumbers.NumberParseException:
raise SynapseError(400, "Unable to parse phone number")
- return phonenumbers.format_number(
- phoneNumber, phonenumbers.PhoneNumberFormat.E164
- )[1:]
+ return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[
+ 1:
+ ]
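(Aside: behaviour is unchanged by the reflow; the helper still returns the E.164 form minus the leading "+". For example, with a reserved UK drama number:)

from synapse.util.msisdn import phone_number_to_msisdn

# "020 7946 0000" in E.164 is "+442079460000"; the helper strips the "+"
assert phone_number_to_msisdn("GB", "020 7946 0000") == "442079460000"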
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
new file mode 100644
index 0000000000..3925927f9f
--- /dev/null
+++ b/synapse/util/patch_inline_callbacks.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+from typing import Any, Callable, List
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+# Tracks if we've already patched inlineCallbacks
+_already_patched = False
+
+
+def do_patch():
+ """
+ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ global _already_patched
+
+ orig_inline_callbacks = defer.inlineCallbacks
+ if _already_patched:
+ return
+
+ def new_inline_callbacks(f):
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ start_context = LoggingContext.current_context()
+ changes = [] # type: List[str]
+ orig = orig_inline_callbacks(_check_yield_points(f, changes))
+
+ try:
+ res = orig(*args, **kwargs)
+ except Exception:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "%s changed context from %s to %s on exception" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ raise
+
+ if not isinstance(res, Deferred) or res.called:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "Completed %s changed context from %s to %s" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ # print the error to stderr because otherwise all we
+ # see in travis-ci is the 500 error
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return res
+
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ err = (
+ "%s returned incomplete deferred in non-sentinel context "
+ "%s (start was %s)"
+ ) % (f, LoggingContext.current_context(), start_context)
+ print(err, file=sys.stderr)
+ raise Exception(err)
+
+ def check_ctx(r):
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+ err = "%s completion of %s changed context from %s to %s" % (
+ "Failure" if isinstance(r, Failure) else "Success",
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return r
+
+ res.addBoth(check_ctx)
+ return res
+
+ return wrapped
+
+ defer.inlineCallbacks = new_inline_callbacks
+ _already_patched = True
+
+
+def _check_yield_points(f: Callable, changes: List[str]):
+ """Wraps a generator that is about to be passed to defer.inlineCallbacks
+ checking that after every yield the log contexts are correct.
+
+ It's perfectly valid for log contexts to change within a function, e.g. due
+ to new Measure blocks, so such changes are added to the given `changes`
+ list instead of triggering an exception.
+
+ Args:
+ f: generator function to wrap
+ changes: A list of strings detailing how the contexts
+ changed within a function.
+
+ Returns:
+ function
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ @functools.wraps(f)
+ def check_yield_points_inner(*args, **kwargs):
+ gen = f(*args, **kwargs)
+
+ last_yield_line_no = gen.gi_frame.f_lineno
+ result = None # type: Any
+ while True:
+ expected_context = LoggingContext.current_context()
+
+ try:
+ isFailure = isinstance(result, Failure)
+ if isFailure:
+ d = result.throwExceptionIntoGenerator(gen)
+ else:
+ d = gen.send(result)
+ except (StopIteration, defer._DefGen_Return) as e:
+ if LoggingContext.current_context() != expected_context:
+ # This happens when the context is lost sometime *after* the
+ # final yield and returning. E.g. we forgot to yield on a
+ # function that returns a deferred.
+ #
+ # We don't raise here as it's perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "Function %r returned and changed context from %s to %s,"
+ " in %s between %d and end of func"
+ % (
+ f.__qualname__,
+ expected_context,
+ LoggingContext.current_context(),
+ f.__code__.co_filename,
+ last_yield_line_no,
+ )
+ )
+ changes.append(err)
+ return getattr(e, "value", None)
+
+ frame = gen.gi_frame
+
+ if isinstance(d, defer.Deferred) and not d.called:
+ # This happens if we yield on a deferred that doesn't follow
+ # the log context rules without wrapping in a `make_deferred_yieldable`.
+ # We raise here as this should never happen.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ err = (
+ "%s yielded with context %s rather than sentinel,"
+ " yielded on line %d in %s"
+ % (
+ frame.f_code.co_name,
+ LoggingContext.current_context(),
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ raise Exception(err)
+
+ try:
+ result = yield d
+ except Exception as e:
+ result = Failure(e)
+
+ if LoggingContext.current_context() != expected_context:
+
+ # This happens because the context is lost sometime *after* the
+ # previous yield and *after* the current yield. E.g. the
+ # deferred we waited on didn't follow the rules, or we forgot to
+ # yield on a function between the two yield points.
+ #
+ # We don't raise here as it's perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "%s changed context from %s to %s, happened between lines %d and %d in %s"
+ % (
+ frame.f_code.co_name,
+ expected_context,
+ LoggingContext.current_context(),
+ last_yield_line_no,
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ changes.append(err)
+
+ last_yield_line_no = frame.f_lineno
+
+ return check_yield_points_inner
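(Aside: this module is intended for test runs only; the patch is global and guarded by _already_patched. Typical activation, as a hedged sketch:)

# e.g. at the very top of the test harness: functions decorated with
# @defer.inlineCallbacks *before* the patch keep the original decorator,
# since they are wrapped at definition time
from synapse.util.patch_inline_callbacks import do_patch

do_patch()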
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index b146d137f4..5ca4521ce3 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -20,7 +20,7 @@ import logging
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
-from synapse.util.logcontext import (
+from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
@@ -36,9 +36,11 @@ class FederationRateLimiter(object):
clock (Clock)
config (FederationRateLimitConfig)
"""
- self.clock = clock
- self._config = config
- self.ratelimiters = {}
+
+ def new_limiter():
+ return _PerHostRatelimiter(clock=clock, config=config)
+
+ self.ratelimiters = collections.defaultdict(new_limiter)
def ratelimit(self, host):
"""Used to ratelimit an incoming request from given host
@@ -53,15 +55,9 @@ class FederationRateLimiter(object):
host (str): Origin of incoming request.
Returns:
- _PerHostRatelimiter
+ context manager which returns a deferred.
"""
- return self.ratelimiters.setdefault(
- host,
- _PerHostRatelimiter(
- clock=self.clock,
- config=self._config,
- )
- ).ratelimit()
+ return self.ratelimiters[host].ratelimit()
class _PerHostRatelimiter(object):
@@ -112,8 +108,7 @@ class _PerHostRatelimiter(object):
# remove any entries from request_times which aren't within the window
self.request_times[:] = [
- r for r in self.request_times
- if time_now - r < self.window_size
+ r for r in self.request_times if time_now - r < self.window_size
]
# reject the request if we already have too many queued up (either
@@ -121,15 +116,13 @@ class _PerHostRatelimiter(object):
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
- retry_after_ms=int(
- self.window_size / self.sleep_limit
- ),
+ retry_after_ms=int(self.window_size / self.sleep_limit)
)
self.request_times.append(time_now)
def queue_request():
- if len(self.current_processing) > self.concurrent_requests:
+ if len(self.current_processing) >= self.concurrent_requests:
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
@@ -143,22 +136,18 @@ class _PerHostRatelimiter(object):
logger.debug(
"Ratelimit [%s]: len(self.request_times)=%d",
- id(request_id), len(self.request_times),
+ id(request_id),
+ len(self.request_times),
)
if len(self.request_times) > self.sleep_limit:
- logger.debug(
- "Ratelimiter: sleeping request for %f sec", self.sleep_sec,
- )
+ logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
- logger.debug(
- "Ratelimit [%s]: Finished sleeping",
- id(request_id),
- )
+ logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
@@ -168,10 +157,7 @@ class _PerHostRatelimiter(object):
ret_defer = queue_request()
def on_start(r):
- logger.debug(
- "Ratelimit [%s]: Processing req",
- id(request_id),
- )
+ logger.debug("Ratelimit [%s]: Processing req", id(request_id))
self.current_processing.add(request_id)
return r
@@ -193,10 +179,7 @@ class _PerHostRatelimiter(object):
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
- logger.debug(
- "Ratelimit [%s]: Processed req",
- id(request_id),
- )
+ logger.debug("Ratelimit [%s]: Processed req", id(request_id))
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.
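(Aside: per the updated ratelimit() docstring, callers treat the return value as a context manager yielding a deferred. A hedged sketch of the calling convention; handle_request is a stand-in:)

from twisted.internet import defer

@defer.inlineCallbacks
def on_incoming_request(ratelimiter, origin, handle_request):
    # may raise LimitExceededError if too many requests are already queued
    with ratelimiter.ratelimit(origin) as wait_deferred:
        yield wait_deferred  # resolves once this host is allowed to proceed
        yield handle_request()  # leaving the block releases our slot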
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 1a77456498..af69587196 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -17,11 +17,20 @@ import random
from twisted.internet import defer
-import synapse.util.logcontext
+import synapse.logging.context
from synapse.api.errors import CodeMessageException
logger = logging.getLogger(__name__)
+# the initial backoff, after the first transaction fails
+MIN_RETRY_INTERVAL = 10 * 60 * 1000
+
+# how much we multiply the backoff by after each subsequent fail
+RETRY_MULTIPLIER = 5
+
+# a cap on the backoff (essentially none)
+MAX_RETRY_INTERVAL = 2 ** 62
+
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
@@ -71,11 +80,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
# We aren't ready to retry that destination.
raise
"""
+ failure_ts = None
retry_last_ts, retry_interval = (0, 0)
retry_timings = yield store.get_destination_retry_timings(destination)
if retry_timings:
+ failure_ts = retry_timings["failure_ts"]
retry_last_ts, retry_interval = (
retry_timings["retry_last_ts"],
retry_timings["retry_interval"],
@@ -95,15 +106,14 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
# maximum backoff even though it might only have been down briefly
backoff_on_failure = not ignore_backoff
- defer.returnValue(
- RetryDestinationLimiter(
- destination,
- clock,
- store,
- retry_interval,
- backoff_on_failure=backoff_on_failure,
- **kwargs
- )
+ return RetryDestinationLimiter(
+ destination,
+ clock,
+ store,
+ failure_ts,
+ retry_interval,
+ backoff_on_failure=backoff_on_failure,
+ **kwargs
)
@@ -113,10 +123,8 @@ class RetryDestinationLimiter(object):
destination,
clock,
store,
+ failure_ts,
retry_interval,
- min_retry_interval=10 * 60 * 1000,
- max_retry_interval=24 * 60 * 60 * 1000,
- multiplier_retry_interval=5,
backoff_on_404=False,
backoff_on_failure=True,
):
@@ -129,15 +137,11 @@ class RetryDestinationLimiter(object):
destination (str)
clock (Clock)
store (DataStore)
+ failure_ts (int|None): when this destination started failing (in ms since
+ the epoch), or None if the last request was successful
retry_interval (int): The next retry interval taken from the
database in milliseconds, or zero if the last request was
successful.
- min_retry_interval (int): The minimum retry interval to use after
- a failed request, in milliseconds.
- max_retry_interval (int): The maximum retry interval to use after
- a failed request, in milliseconds.
- multiplier_retry_interval (int): The multiplier to use to increase
- the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
backoff_on_failure (bool): set to False if we should not increase the
@@ -147,10 +151,8 @@ class RetryDestinationLimiter(object):
self.store = store
self.destination = destination
+ self.failure_ts = failure_ts
self.retry_interval = retry_interval
- self.min_retry_interval = min_retry_interval
- self.max_retry_interval = max_retry_interval
- self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
self.backoff_on_failure = backoff_on_failure
@@ -191,6 +193,7 @@ class RetryDestinationLimiter(object):
logger.debug(
"Connection to %s was successful; clearing backoff", self.destination
)
+ self.failure_ts = None
retry_last_ts = 0
self.retry_interval = 0
elif not self.backoff_on_failure:
@@ -198,13 +201,14 @@ class RetryDestinationLimiter(object):
else:
# We couldn't connect.
if self.retry_interval:
- self.retry_interval *= self.multiplier_retry_interval
- self.retry_interval *= int(random.uniform(0.8, 1.4))
+ self.retry_interval = int(
+ self.retry_interval * RETRY_MULTIPLIER * random.uniform(0.8, 1.4)
+ )
- if self.retry_interval >= self.max_retry_interval:
- self.retry_interval = self.max_retry_interval
+ if self.retry_interval >= MAX_RETRY_INTERVAL:
+ self.retry_interval = MAX_RETRY_INTERVAL
else:
- self.retry_interval = self.min_retry_interval
+ self.retry_interval = MIN_RETRY_INTERVAL
logger.info(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
@@ -215,14 +219,20 @@ class RetryDestinationLimiter(object):
)
retry_last_ts = int(self.clock.time_msec())
+ if self.failure_ts is None:
+ self.failure_ts = retry_last_ts
+
@defer.inlineCallbacks
def store_retry_timings():
try:
yield self.store.set_destination_retry_timings(
- self.destination, retry_last_ts, self.retry_interval
+ self.destination,
+ self.failure_ts,
+ retry_last_ts,
+ self.retry_interval,
)
except Exception:
logger.exception("Failed to store destination_retry_timings")
# we deliberately do this in the background.
- synapse.util.logcontext.run_in_background(store_retry_timings)
+ synapse.logging.context.run_in_background(store_retry_timings)
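(Aside: to make the new constants concrete, each failure multiplies the interval by RETRY_MULTIPLIER plus jitter in [0.8, 1.4), capped at MAX_RETRY_INTERVAL. Ignoring jitter:)

MIN_RETRY_INTERVAL = 10 * 60 * 1000  # ms
RETRY_MULTIPLIER = 5
MAX_RETRY_INTERVAL = 2 ** 62

interval = 0
for failure in range(1, 6):
    interval = (
        min(interval * RETRY_MULTIPLIER, MAX_RETRY_INTERVAL)
        if interval
        else MIN_RETRY_INTERVAL
    )
    print("after failure %d: %d min" % (failure, interval // 60000))
# after failure 1: 10 min
# after failure 2: 50 min
# after failure 3: 250 min
# after failure 4: 1250 min
# after failure 5: 6250 min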
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
index 6c0f2bb0cf..207cd17c2a 100644
--- a/synapse/util/rlimit.py
+++ b/synapse/util/rlimit.py
@@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no):
resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
)
except (ValueError, resource.error) as e:
- logger.warn("Failed to set file or core limit: %s", e)
+ logger.warning("Failed to set file or core limit: %s", e)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 5fb18ee1f8..2c0dcb5208 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -24,26 +24,25 @@ from six.moves import range
from synapse.api.errors import Codes, SynapseError
-_string_with_symbols = (
- string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
-)
+_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
+
+# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
+# Note: The : character is allowed here for older clients, but will be removed in a
+# future release. Context: https://github.com/matrix-org/synapse/issues/6766
+client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$")
# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
# we get cryptographically-secure randoms.
rand = random.SystemRandom()
-client_secret_regex = re.compile(r"^[0-9a-zA-Z.=_-]+$")
-
def random_string(length):
- return ''.join(rand.choice(string.ascii_letters) for _ in range(length))
+ return "".join(rand.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
- return ''.join(
- rand.choice(_string_with_symbols) for _ in range(length)
- )
+ return "".join(rand.choice(_string_with_symbols) for _ in range(length))
def is_ascii(s):
@@ -51,7 +50,7 @@ def is_ascii(s):
if PY3:
if isinstance(s, bytes):
try:
- s.decode('ascii').encode('ascii')
+ s.decode("ascii").encode("ascii")
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
@@ -110,13 +109,13 @@ def exception_to_unicode(e):
# and instead look at what is in the args member.
if len(e.args) == 0:
- return u""
+ return ""
elif len(e.args) > 1:
return six.text_type(repr(e.args))
msg = e.args[0]
if isinstance(msg, bytes):
- return msg.decode('utf-8', errors='replace')
+ return msg.decode("utf-8", errors="replace")
else:
return msg
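(Aside: a quick check against the relaxed client_secret_regex; note the temporarily-allowed ":" discussed in the comment above:)

import re

client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$")

assert client_secret_regex.match("this.is_a-valid=secret:still")
assert not client_secret_regex.match("no spaces allowed")
assert not client_secret_regex.match("")  # must be at least one character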
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 4cc7d27ce5..34ce7cac16 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -36,25 +36,26 @@ def check_3pid_allowed(hs, medium, address):
if hs.config.check_is_for_allowed_local_3pids:
data = yield hs.get_simple_http_client().get_json(
- "https://%s%s" % (
+ "https://%s%s"
+ % (
hs.config.check_is_for_allowed_local_3pids,
- "/_matrix/identity/api/v1/internal-info"
+ "/_matrix/identity/api/v1/internal-info",
),
- {'medium': medium, 'address': address}
+ {"medium": medium, "address": address},
)
# Check for invalid response
- if 'hs' not in data and 'shadow_hs' not in data:
+ if "hs" not in data and "shadow_hs" not in data:
defer.returnValue(False)
# Check if this user is intended to register for this homeserver
if (
- data.get('hs') != hs.config.server_name
- and data.get('shadow_hs') != hs.config.server_name
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
):
defer.returnValue(False)
- if data.get('requires_invite', False) and not data.get('invited', False):
+ if data.get("requires_invite", False) and not data.get("invited", False):
# Requires an invite but hasn't been invited
defer.returnValue(False)
@@ -64,11 +65,13 @@ def check_3pid_allowed(hs, medium, address):
for constraint in hs.config.allowed_local_3pids:
logger.debug(
"Checking 3PID %s (%s) against %s (%s)",
- address, medium, constraint['pattern'], constraint['medium'],
+ address,
+ medium,
+ constraint["pattern"],
+ constraint["medium"],
)
- if (
- medium == constraint['medium'] and
- re.match(constraint['pattern'], address)
+ if medium == constraint["medium"] and re.match(
+ constraint["pattern"], address
):
defer.returnValue(True)
else:
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 3baba3225a..ab7d03af3a 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -22,63 +22,87 @@ logger = logging.getLogger(__name__)
def get_version_string(module):
+ """Given a module calculate a git-aware version string for it.
+
+ If called on a module not in a git checkout will return `__verison__`.
+
+ Args:
+ module (module)
+
+ Returns:
+ str
+ """
+
+ cached_version = getattr(module, "_synapse_version_string_cache", None)
+ if cached_version:
+ return cached_version
+
+ version_string = module.__version__
+
try:
- null = open(os.devnull, 'w')
+ null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
+
try:
- git_branch = subprocess.check_output(
- ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii')
+ git_branch = (
+ subprocess.check_output(
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ )
git_branch = "b=" + git_branch
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ # FileNotFoundError can arise when git is not installed
git_branch = ""
try:
- git_tag = subprocess.check_output(
- ['git', 'describe', '--exact-match'],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii')
+ git_tag = (
+ subprocess.check_output(
+ ["git", "describe", "--exact-match"], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ )
git_tag = "t=" + git_tag
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_tag = ""
try:
- git_commit = subprocess.check_output(
- ['git', 'rev-parse', '--short', 'HEAD'],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii')
- except subprocess.CalledProcessError:
+ git_commit = (
+ subprocess.check_output(
+ ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ )
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_commit = ""
try:
dirty_string = "-this_is_a_dirty_checkout"
- is_dirty = subprocess.check_output(
- ['git', 'describe', '--dirty=' + dirty_string],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii').endswith(dirty_string)
+ is_dirty = (
+ subprocess.check_output(
+ ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ .endswith(dirty_string)
+ )
git_dirty = "dirty" if is_dirty else ""
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty:
git_version = ",".join(
- s for s in
- (git_branch, git_tag, git_commit, git_dirty,)
- if s
+ s for s in (git_branch, git_tag, git_commit, git_dirty) if s
)
- return (
- "%s (%s)" % (
- module.__version__, git_version,
- )
- )
+ version_string = "%s (%s)" % (module.__version__, git_version)
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return module.__version__
+ module._synapse_version_string_cache = version_string
+
+ return version_string
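(Aside: the refactored function now probes git at most once per module, caching the result on the module object itself:)

import synapse
from synapse.util.versionstring import get_version_string

v1 = get_version_string(synapse)  # may shell out to git the first time
v2 = get_version_string(synapse)  # served from _synapse_version_string_cache
assert v1 == v2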
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 7a9e45aca9..9bf6a44f75 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -69,9 +69,7 @@ class WheelTimer(object):
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
- self.entries.extend(
- _Entry(key) for key in range(last_key, then_key + 1)
- )
+ self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))
self.entries[-1].queue.append(obj)