diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 25ea1bcc91..34c662c4db 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar
+
+import attr
from twisted.internet import defer
@@ -23,10 +25,36 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-T = TypeVar("T")
+# the type of the key in the cache
+KV = TypeVar("KV")
+
+# the type of the result from the operation
+RV = TypeVar("RV")
+
+@attr.s(auto_attribs=True)
+class ResponseCacheContext(Generic[KV]):
+ """Information about a missed ResponseCache hit
-class ResponseCache(Generic[T]):
+ This object can be passed into the callback for additional feedback
+ """
+
+ cache_key: KV
+ """The cache key that caused the cache miss
+
+ This should be considered read-only.
+
+ TODO: in attrs 20.1, make it frozen with an on_setattr.
+ """
+
+ should_cache: bool = True
+ """Whether the result should be cached once the request completes.
+
+ This can be modified by the callback if it decides its result should not be cached.
+ """
+
+
+class ResponseCache(Generic[KV]):
"""
This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request
@@ -35,8 +63,10 @@ class ResponseCache(Generic[T]):
"""
def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
- # Requests that haven't finished yet.
- self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
+ # This is poorly-named: it includes both complete and incomplete results.
+ # We keep complete results rather than switching to absolute values because
+ # that makes it easier to cache Failure results.
+ self.pending_result_cache = {} # type: Dict[KV, ObservableDeferred]
self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
@@ -50,16 +80,13 @@ class ResponseCache(Generic[T]):
def __len__(self) -> int:
return self.size()
- def get(self, key: T) -> Optional[defer.Deferred]:
+ def get(self, key: KV) -> Optional[defer.Deferred]:
"""Look up the given key.
- Can return either a new Deferred (which also doesn't follow the synapse
- logcontext rules), or, if the request has completed, the actual
- result. You will probably want to make_deferred_yieldable the result.
+ Returns a new Deferred (which also doesn't follow the synapse
+ logcontext rules). You will probably want to make_deferred_yieldable the result.
- If there is no entry for the key, returns None. It is worth noting that
- this means there is no way to distinguish a completed result of None
- from an absent cache entry.
+ If there is no entry for the key, returns None.
Args:
key: key to get/set in the cache
@@ -76,42 +103,56 @@ class ResponseCache(Generic[T]):
self._metrics.inc_misses()
return None
- def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
+ def _set(
+ self, context: ResponseCacheContext[KV], deferred: defer.Deferred
+ ) -> defer.Deferred:
"""Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie,
you should wrap normal synapse deferreds with
synapse.logging.context.run_in_background).
- Can return either a new Deferred (which also doesn't follow the synapse
- logcontext rules), or, if *deferred* was already complete, the actual
- result. You will probably want to make_deferred_yieldable the result.
+ Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
+ You will probably want to make_deferred_yieldable the result.
Args:
- key: key to get/set in the cache
+ context: Information about the cache miss
deferred: The deferred which resolves to the result.
Returns:
A new deferred which resolves to the actual result.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
+ key = context.cache_key
self.pending_result_cache[key] = result
- def remove(r):
- if self.timeout_sec:
+ def on_complete(r):
+ # if this cache has a non-zero timeout, and the callback has not cleared
+ # the should_cache bit, we leave it in the cache for now and schedule
+ # its removal later.
+ if self.timeout_sec and context.should_cache:
self.clock.call_later(
self.timeout_sec, self.pending_result_cache.pop, key, None
)
else:
+ # otherwise, remove the result immediately.
self.pending_result_cache.pop(key, None)
return r
- result.addBoth(remove)
+ # make sure we do this *after* adding the entry to pending_result_cache,
+ # in case the result is already complete (in which case flipping the order would
+ # leave us with a stuck entry in the cache).
+ result.addBoth(on_complete)
return result.observe()
- def wrap(
- self, key: T, callback: Callable[..., Any], *args: Any, **kwargs: Any
- ) -> defer.Deferred:
+ async def wrap(
+ self,
+ key: KV,
+ callback: Callable[..., Awaitable[RV]],
+ *args: Any,
+ cache_context: bool = False,
+ **kwargs: Any,
+ ) -> RV:
"""Wrap together a *get* and *set* call, taking care of logcontexts
First looks up the key in the cache, and if it is present makes it
@@ -140,22 +181,28 @@ class ResponseCache(Generic[T]):
*args: positional parameters to pass to the callback, if it is used
+ cache_context: if set, the callback will be given a `cache_context` kw arg,
+ which will be a ResponseCacheContext object.
+
**kwargs: named parameters to pass to the callback, if it is used
Returns:
- Deferred which resolves to the result
+ The result of the callback (from the cache, or otherwise)
"""
result = self.get(key)
if not result:
logger.debug(
"[%s]: no cached result for [%s], calculating new one", self._name, key
)
+ context = ResponseCacheContext(cache_key=key)
+ if cache_context:
+ kwargs["cache_context"] = context
d = run_in_background(callback, *args, **kwargs)
- result = self.set(key, d)
+ result = self._set(context, d)
elif not isinstance(result, defer.Deferred) or result.called:
logger.info("[%s]: using completed cached result for [%s]", self._name, key)
else:
logger.info(
"[%s]: using incomplete cached result for [%s]", self._name, key
)
- return make_deferred_yieldable(result)
+ return await make_deferred_yieldable(result)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 6d14351bd2..45353d41c5 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -133,12 +133,17 @@ class Measure:
self.start = self.clock.time()
self._logging_context.__enter__()
in_flight.register((self.name,), self._update_in_flight)
+
+ logger.debug("Entering block %s", self.name)
+
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.start is None:
raise RuntimeError("Measure() block exited without being entered")
+ logger.debug("Exiting block %s", self.name)
+
duration = self.clock.time() - self.start
usage = self.get_resource_usage()
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index cbfbd097f9..5a638c6e9a 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -51,21 +51,26 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
# Load the module config. If None, pass an empty dictionary instead
module_config = provider.get("config") or {}
- try:
- provider_config = provider_class.parse_config(module_config)
- except jsonschema.ValidationError as e:
- raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
- except ConfigError as e:
- raise _wrap_config_error(
- "Failed to parse config for module %r" % (modulename,),
- prefix=itertools.chain(config_path, ("config",)),
- e=e,
- )
- except Exception as e:
- raise ConfigError(
- "Failed to parse config for module %r" % (modulename,),
- path=itertools.chain(config_path, ("config",)),
- ) from e
+ if hasattr(provider_class, "parse_config"):
+ try:
+ provider_config = provider_class.parse_config(module_config)
+ except jsonschema.ValidationError as e:
+ raise json_error_to_config_error(
+ e, itertools.chain(config_path, ("config",))
+ )
+ except ConfigError as e:
+ raise _wrap_config_error(
+ "Failed to parse config for module %r" % (modulename,),
+ prefix=itertools.chain(config_path, ("config",)),
+ e=e,
+ )
+ except Exception as e:
+ raise ConfigError(
+ "Failed to parse config for module %r" % (modulename,),
+ path=itertools.chain(config_path, ("config",)),
+ ) from e
+ else:
+ provider_config = module_config
return provider_class, provider_config
|