diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f1c46836b1..804dbca443 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,12 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import collections
import logging
from contextlib import contextmanager
+from typing import Dict, Sequence, Set, Union
from six.moves import range
+import attr
+
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
@@ -213,7 +217,9 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
- self.key_to_defer = {}
+ self.key_to_defer = (
+ {}
+ ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@@ -340,10 +346,10 @@ class ReadWriteLock(object):
def __init__(self):
# Latest readers queued
- self.key_to_current_readers = {}
+ self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued
- self.key_to_current_writer = {}
+ self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):
@@ -479,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb)
return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+ """Simple awaitable that returns the provided value.
+ """
+
+ value = attr.ib()
+
+ def __await__(self):
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+ """Convert a value to an awaitable if not already an awaitable.
+ """
+
+ if hasattr(value, "__await__"):
+ return value
+
+ return DoneAwaitable(value)
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index b50e3503f0..43fd65d693 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -16,6 +16,7 @@
import logging
import os
+from typing import Dict
import six
from six.moves import intern
@@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {}
-collectors_by_name = {}
+collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 43f66ec4be..5ac2530a6a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,10 +18,12 @@ import inspect
import logging
import threading
from collections import namedtuple
+from typing import Any, cast
from six import itervalues
from prometheus_client import Gauge
+from typing_extensions import Protocol
from twisted.internet import defer
@@ -37,6 +39,18 @@ from . import register_cache
logger = logging.getLogger(__name__)
+class _CachedFunction(Protocol):
+ invalidate = None # type: Any
+ invalidate_all = None # type: Any
+ invalidate_many = None # type: Any
+ prefill = None # type: Any
+ cache = None # type: Any
+ num_args = None # type: Any
+
+ def __name__(self):
+ ...
+
+
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
@@ -245,7 +259,9 @@ class Cache(object):
class _CacheDescriptorBase(object):
- def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+ def __init__(
+ self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
+ ):
self.orig = orig
if inlineCallbacks:
@@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
+ def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(observer)
+ wrapped = cast(_CachedFunction, _wrapped)
+
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 9a72218d85..2ea4e4e911 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,3 +1,5 @@
+from typing import Dict
+
from six import itervalues
SENTINEL = object()
@@ -12,7 +14,7 @@ class TreeCache(object):
def __init__(self):
self.size = 0
- self.root = {}
+ self.root = {} # type: Dict
def __setitem__(self, key, value):
return self.set(key, value)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 0910930c21..4b1bcdf23c 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -60,12 +60,14 @@ in_flight = InFlightGauge(
)
-def measure_func(name):
+def measure_func(name=None):
def wrapper(func):
+ block_name = func.__name__ if name is None else name
+
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
- with Measure(self.clock, name):
+ with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 7ff7eb1e4d..2705cbe5f8 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -54,5 +54,5 @@ def load_python_module(location: str):
if spec is None:
raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(mod)
+ spec.loader.exec_module(mod) # type: ignore
return mod
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
new file mode 100644
index 0000000000..3925927f9f
--- /dev/null
+++ b/synapse/util/patch_inline_callbacks.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+from typing import Any, Callable, List
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+# Tracks if we've already patched inlineCallbacks
+_already_patched = False
+
+
+def do_patch():
+ """
+ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ global _already_patched
+
+ orig_inline_callbacks = defer.inlineCallbacks
+ if _already_patched:
+ return
+
+ def new_inline_callbacks(f):
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ start_context = LoggingContext.current_context()
+ changes = [] # type: List[str]
+ orig = orig_inline_callbacks(_check_yield_points(f, changes))
+
+ try:
+ res = orig(*args, **kwargs)
+ except Exception:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "%s changed context from %s to %s on exception" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ raise
+
+ if not isinstance(res, Deferred) or res.called:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "Completed %s changed context from %s to %s" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ # print the error to stderr because otherwise all we
+ # see in travis-ci is the 500 error
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return res
+
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ err = (
+ "%s returned incomplete deferred in non-sentinel context "
+ "%s (start was %s)"
+ ) % (f, LoggingContext.current_context(), start_context)
+ print(err, file=sys.stderr)
+ raise Exception(err)
+
+ def check_ctx(r):
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+ err = "%s completion of %s changed context from %s to %s" % (
+ "Failure" if isinstance(r, Failure) else "Success",
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return r
+
+ res.addBoth(check_ctx)
+ return res
+
+ return wrapped
+
+ defer.inlineCallbacks = new_inline_callbacks
+ _already_patched = True
+
+
+def _check_yield_points(f: Callable, changes: List[str]):
+ """Wraps a generator that is about to be passed to defer.inlineCallbacks
+ checking that after every yield the log contexts are correct.
+
+ It's perfectly valid for log contexts to change within a function, e.g. due
+ to new Measure blocks, so such changes are added to the given `changes`
+ list instead of triggering an exception.
+
+ Args:
+ f: generator function to wrap
+ changes: A list of strings detailing how the contexts
+ changed within a function.
+
+ Returns:
+ function
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ @functools.wraps(f)
+ def check_yield_points_inner(*args, **kwargs):
+ gen = f(*args, **kwargs)
+
+ last_yield_line_no = gen.gi_frame.f_lineno
+ result = None # type: Any
+ while True:
+ expected_context = LoggingContext.current_context()
+
+ try:
+ isFailure = isinstance(result, Failure)
+ if isFailure:
+ d = result.throwExceptionIntoGenerator(gen)
+ else:
+ d = gen.send(result)
+ except (StopIteration, defer._DefGen_Return) as e:
+ if LoggingContext.current_context() != expected_context:
+ # This happens when the context is lost sometime *after* the
+ # final yield and returning. E.g. we forgot to yield on a
+ # function that returns a deferred.
+ #
+ # We don't raise here as it's perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "Function %r returned and changed context from %s to %s,"
+ " in %s between %d and end of func"
+ % (
+ f.__qualname__,
+ expected_context,
+ LoggingContext.current_context(),
+ f.__code__.co_filename,
+ last_yield_line_no,
+ )
+ )
+ changes.append(err)
+ return getattr(e, "value", None)
+
+ frame = gen.gi_frame
+
+ if isinstance(d, defer.Deferred) and not d.called:
+ # This happens if we yield on a deferred that doesn't follow
+ # the log context rules without wrapping in a `make_deferred_yieldable`.
+ # We raise here as this should never happen.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ err = (
+ "%s yielded with context %s rather than sentinel,"
+ " yielded on line %d in %s"
+ % (
+ frame.f_code.co_name,
+ LoggingContext.current_context(),
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ raise Exception(err)
+
+ try:
+ result = yield d
+ except Exception as e:
+ result = Failure(e)
+
+ if LoggingContext.current_context() != expected_context:
+
+ # This happens because the context is lost sometime *after* the
+ # previous yield and *after* the current yield. E.g. the
+ # deferred we waited on didn't follow the rules, or we forgot to
+ # yield on a function between the two yield points.
+ #
+ # We don't raise here as its perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "%s changed context from %s to %s, happened between lines %d and %d in %s"
+ % (
+ frame.f_code.co_name,
+ expected_context,
+ LoggingContext.current_context(),
+ last_yield_line_no,
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ changes.append(err)
+
+ last_yield_line_no = frame.f_lineno
+
+ return check_yield_points_inner
|