diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f1c46836b1..804dbca443 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,12 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import collections
import logging
from contextlib import contextmanager
+from typing import Dict, Sequence, Set, Union
from six.moves import range
+import attr
+
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
@@ -213,7 +217,9 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
- self.key_to_defer = {}
+ self.key_to_defer = (
+ {}
+ ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@@ -340,10 +346,10 @@ class ReadWriteLock(object):
def __init__(self):
# Latest readers queued
- self.key_to_current_readers = {}
+ self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued
- self.key_to_current_writer = {}
+ self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):
@@ -479,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb)
return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+ """Simple awaitable that returns the provided value.
+ """
+
+ value = attr.ib()
+
+ def __await__(self):
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+ """Convert a value to an awaitable if not already an awaitable.
+ """
+
+ if hasattr(value, "__await__"):
+ return value
+
+ return DoneAwaitable(value)
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index b50e3503f0..43fd65d693 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -16,6 +16,7 @@
import logging
import os
+from typing import Dict
import six
from six.moves import intern
@@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {}
-collectors_by_name = {}
+collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 43f66ec4be..5ac2530a6a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,10 +18,12 @@ import inspect
import logging
import threading
from collections import namedtuple
+from typing import Any, cast
from six import itervalues
from prometheus_client import Gauge
+from typing_extensions import Protocol
from twisted.internet import defer
@@ -37,6 +39,18 @@ from . import register_cache
logger = logging.getLogger(__name__)
+class _CachedFunction(Protocol):
+ invalidate = None # type: Any
+ invalidate_all = None # type: Any
+ invalidate_many = None # type: Any
+ prefill = None # type: Any
+ cache = None # type: Any
+ num_args = None # type: Any
+
+ def __name__(self):
+ ...
+
+
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
@@ -245,7 +259,9 @@ class Cache(object):
class _CacheDescriptorBase(object):
- def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+ def __init__(
+ self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
+ ):
self.orig = orig
if inlineCallbacks:
@@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
+ def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(observer)
+ wrapped = cast(_CachedFunction, _wrapped)
+
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 9a72218d85..2ea4e4e911 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,3 +1,5 @@
+from typing import Dict
+
from six import itervalues
SENTINEL = object()
@@ -12,7 +14,7 @@ class TreeCache(object):
def __init__(self):
self.size = 0
- self.root = {}
+ self.root = {} # type: Dict
def __setitem__(self, key, value):
return self.set(key, value)
diff --git a/synapse/util/hash.py b/synapse/util/hash.py
new file mode 100644
index 0000000000..359168704e
--- /dev/null
+++ b/synapse/util/hash.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+
+import unpaddedbase64
+
+
+def sha256_and_url_safe_base64(input_text):
+ """SHA256 hash an input string, encode the digest as url-safe base64, and
+ return
+
+ :param input_text: string to hash
+ :type input_text: str
+
+ :returns a sha256 hashed and url-safe base64 encoded digest
+ :rtype: str
+ """
+ digest = hashlib.sha256(input_text.encode()).digest()
+ return unpaddedbase64.encode_base64(digest, urlsafe=True)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 0910930c21..4b1bcdf23c 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -60,12 +60,14 @@ in_flight = InFlightGauge(
)
-def measure_func(name):
+def measure_func(name=None):
def wrapper(func):
+ block_name = func.__name__ if name is None else name
+
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
- with Measure(self.clock, name):
+ with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 522acd5aa8..2705cbe5f8 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -14,12 +14,13 @@
# limitations under the License.
import importlib
+import importlib.util
from synapse.config._base import ConfigError
def load_module(provider):
- """ Loads a module with its config
+ """ Loads a synapse module with its config
Take a dict with keys 'module' (the module name) and 'config'
(the config dict).
@@ -38,3 +39,20 @@ def load_module(provider):
raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
return provider_class, provider_config
+
+
+def load_python_module(location: str):
+ """Load a python module, and return a reference to its global namespace
+
+ Args:
+ location (str): path to the module
+
+ Returns:
+ python module object
+ """
+ spec = importlib.util.spec_from_file_location(location, location)
+ if spec is None:
+ raise Exception("Unable to load module at %s" % (location,))
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod) # type: ignore
+ return mod
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
new file mode 100644
index 0000000000..3925927f9f
--- /dev/null
+++ b/synapse/util/patch_inline_callbacks.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+from typing import Any, Callable, List
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+# Tracks if we've already patched inlineCallbacks
+_already_patched = False
+
+
+def do_patch():
+ """
+ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ global _already_patched
+
+ orig_inline_callbacks = defer.inlineCallbacks
+ if _already_patched:
+ return
+
+ def new_inline_callbacks(f):
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ start_context = LoggingContext.current_context()
+ changes = [] # type: List[str]
+ orig = orig_inline_callbacks(_check_yield_points(f, changes))
+
+ try:
+ res = orig(*args, **kwargs)
+ except Exception:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "%s changed context from %s to %s on exception" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ raise
+
+ if not isinstance(res, Deferred) or res.called:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "Completed %s changed context from %s to %s" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ # print the error to stderr because otherwise all we
+ # see in travis-ci is the 500 error
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return res
+
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ err = (
+ "%s returned incomplete deferred in non-sentinel context "
+ "%s (start was %s)"
+ ) % (f, LoggingContext.current_context(), start_context)
+ print(err, file=sys.stderr)
+ raise Exception(err)
+
+ def check_ctx(r):
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+ err = "%s completion of %s changed context from %s to %s" % (
+ "Failure" if isinstance(r, Failure) else "Success",
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return r
+
+ res.addBoth(check_ctx)
+ return res
+
+ return wrapped
+
+ defer.inlineCallbacks = new_inline_callbacks
+ _already_patched = True
+
+
+def _check_yield_points(f: Callable, changes: List[str]):
+ """Wraps a generator that is about to be passed to defer.inlineCallbacks
+ checking that after every yield the log contexts are correct.
+
+ It's perfectly valid for log contexts to change within a function, e.g. due
+ to new Measure blocks, so such changes are added to the given `changes`
+ list instead of triggering an exception.
+
+ Args:
+ f: generator function to wrap
+ changes: A list of strings detailing how the contexts
+ changed within a function.
+
+ Returns:
+ function
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ @functools.wraps(f)
+ def check_yield_points_inner(*args, **kwargs):
+ gen = f(*args, **kwargs)
+
+ last_yield_line_no = gen.gi_frame.f_lineno
+ result = None # type: Any
+ while True:
+ expected_context = LoggingContext.current_context()
+
+ try:
+ isFailure = isinstance(result, Failure)
+ if isFailure:
+ d = result.throwExceptionIntoGenerator(gen)
+ else:
+ d = gen.send(result)
+ except (StopIteration, defer._DefGen_Return) as e:
+ if LoggingContext.current_context() != expected_context:
+ # This happens when the context is lost sometime *after* the
+ # final yield and returning. E.g. we forgot to yield on a
+ # function that returns a deferred.
+ #
+ # We don't raise here as it's perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "Function %r returned and changed context from %s to %s,"
+ " in %s between %d and end of func"
+ % (
+ f.__qualname__,
+ expected_context,
+ LoggingContext.current_context(),
+ f.__code__.co_filename,
+ last_yield_line_no,
+ )
+ )
+ changes.append(err)
+ return getattr(e, "value", None)
+
+ frame = gen.gi_frame
+
+ if isinstance(d, defer.Deferred) and not d.called:
+ # This happens if we yield on a deferred that doesn't follow
+ # the log context rules without wrapping in a `make_deferred_yieldable`.
+ # We raise here as this should never happen.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ err = (
+ "%s yielded with context %s rather than sentinel,"
+ " yielded on line %d in %s"
+ % (
+ frame.f_code.co_name,
+ LoggingContext.current_context(),
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ raise Exception(err)
+
+ try:
+ result = yield d
+ except Exception as e:
+ result = Failure(e)
+
+ if LoggingContext.current_context() != expected_context:
+
+ # This happens because the context is lost sometime *after* the
+ # previous yield and *after* the current yield. E.g. the
+ # deferred we waited on didn't follow the rules, or we forgot to
+ # yield on a function between the two yield points.
+ #
+ # We don't raise here as its perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "%s changed context from %s to %s, happened between lines %d and %d in %s"
+ % (
+ frame.f_code.co_name,
+ expected_context,
+ LoggingContext.current_context(),
+ last_yield_line_no,
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ changes.append(err)
+
+ last_yield_line_no = frame.f_lineno
+
+ return check_yield_points_inner
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 0862b5ca5a..af69587196 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -22,6 +22,15 @@ from synapse.api.errors import CodeMessageException
logger = logging.getLogger(__name__)
+# the intial backoff, after the first transaction fails
+MIN_RETRY_INTERVAL = 10 * 60 * 1000
+
+# how much we multiply the backoff by after each subsequent fail
+RETRY_MULTIPLIER = 5
+
+# a cap on the backoff. (Essentially none)
+MAX_RETRY_INTERVAL = 2 ** 62
+
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
@@ -71,11 +80,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
# We aren't ready to retry that destination.
raise
"""
+ failure_ts = None
retry_last_ts, retry_interval = (0, 0)
retry_timings = yield store.get_destination_retry_timings(destination)
if retry_timings:
+ failure_ts = retry_timings["failure_ts"]
retry_last_ts, retry_interval = (
retry_timings["retry_last_ts"],
retry_timings["retry_interval"],
@@ -99,6 +110,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
destination,
clock,
store,
+ failure_ts,
retry_interval,
backoff_on_failure=backoff_on_failure,
**kwargs
@@ -111,10 +123,8 @@ class RetryDestinationLimiter(object):
destination,
clock,
store,
+ failure_ts,
retry_interval,
- min_retry_interval=10 * 60 * 1000,
- max_retry_interval=24 * 60 * 60 * 1000,
- multiplier_retry_interval=5,
backoff_on_404=False,
backoff_on_failure=True,
):
@@ -127,15 +137,11 @@ class RetryDestinationLimiter(object):
destination (str)
clock (Clock)
store (DataStore)
+ failure_ts (int|None): when this destination started failing (in ms since
+ the epoch), or zero if the last request was successful
retry_interval (int): The next retry interval taken from the
database in milliseconds, or zero if the last request was
successful.
- min_retry_interval (int): The minimum retry interval to use after
- a failed request, in milliseconds.
- max_retry_interval (int): The maximum retry interval to use after
- a failed request, in milliseconds.
- multiplier_retry_interval (int): The multiplier to use to increase
- the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
backoff_on_failure (bool): set to False if we should not increase the
@@ -145,10 +151,8 @@ class RetryDestinationLimiter(object):
self.store = store
self.destination = destination
+ self.failure_ts = failure_ts
self.retry_interval = retry_interval
- self.min_retry_interval = min_retry_interval
- self.max_retry_interval = max_retry_interval
- self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
self.backoff_on_failure = backoff_on_failure
@@ -189,6 +193,7 @@ class RetryDestinationLimiter(object):
logger.debug(
"Connection to %s was successful; clearing backoff", self.destination
)
+ self.failure_ts = None
retry_last_ts = 0
self.retry_interval = 0
elif not self.backoff_on_failure:
@@ -196,13 +201,14 @@ class RetryDestinationLimiter(object):
else:
# We couldn't connect.
if self.retry_interval:
- self.retry_interval *= self.multiplier_retry_interval
- self.retry_interval *= int(random.uniform(0.8, 1.4))
+ self.retry_interval = int(
+ self.retry_interval * RETRY_MULTIPLIER * random.uniform(0.8, 1.4)
+ )
- if self.retry_interval >= self.max_retry_interval:
- self.retry_interval = self.max_retry_interval
+ if self.retry_interval >= MAX_RETRY_INTERVAL:
+ self.retry_interval = MAX_RETRY_INTERVAL
else:
- self.retry_interval = self.min_retry_interval
+ self.retry_interval = MIN_RETRY_INTERVAL
logger.info(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
@@ -213,11 +219,17 @@ class RetryDestinationLimiter(object):
)
retry_last_ts = int(self.clock.time_msec())
+ if self.failure_ts is None:
+ self.failure_ts = retry_last_ts
+
@defer.inlineCallbacks
def store_retry_timings():
try:
yield self.store.set_destination_retry_timings(
- self.destination, retry_last_ts, self.retry_interval
+ self.destination,
+ self.failure_ts,
+ retry_last_ts,
+ self.retry_interval,
)
except Exception:
logger.exception("Failed to store destination_retry_timings")
|