diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index a90f08dd4c..7be9d5f113 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -15,7 +15,7 @@
import json
import logging
import typing
-from typing import Any, Callable, Dict, Generator, Optional
+from typing import Any, Callable, Dict, Generator, Optional, Sequence
import attr
from frozendict import frozendict
@@ -193,3 +193,15 @@ def log_failure(
# Version string with git info. Computed here once so that we don't invoke git multiple
# times.
SYNAPSE_VERSION = get_distribution_version_string("matrix-synapse", __file__)
+
+
+class ExceptionBundle(Exception):
+ # A poor stand-in for something like Python 3.11's ExceptionGroup.
+ # (A backport called `exceptiongroup` exists but seems overkill: we just want a
+ # container type here.)
+ def __init__(self, message: str, exceptions: Sequence[Exception]):
+ parts = [message]
+ for e in exceptions:
+ parts.append(str(e))
+ super().__init__("\n - ".join(parts))
+ self.exceptions = exceptions
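
As a quick illustration of the intended use, here is a minimal sketch (not part of the diff) of collecting failures and re-raising them as one bundle; `close_many` and the `resources` it iterates are hypothetical:

```python
from synapse.util import ExceptionBundle

def close_many(resources) -> None:
    # Hypothetical helper: try to close everything, remembering failures,
    # then report them all at once instead of aborting at the first error.
    exceptions = []
    for resource in resources:
        try:
            resource.close()
        except Exception as e:
            exceptions.append(e)

    if exceptions:
        raise ExceptionBundle("Failed to close some resources", exceptions)
```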
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 7f1d41eb3c..d24c4f68c4 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -217,7 +217,8 @@ async def concurrently_execute(
limit: Maximum number of concurrent executions.
Returns:
- Deferred: Resolved when all function invocations have finished.
+ None, when all function invocations have finished. The return values
+ from those functions are discarded.
"""
it = iter(args)
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 42f6abb5e1..9387632d0d 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -20,9 +20,11 @@ from sys import intern
from typing import Any, Callable, Dict, List, Optional, Sized, TypeVar
import attr
+from prometheus_client import REGISTRY
from prometheus_client.core import Gauge
from synapse.config.cache import add_resizable_cache
+from synapse.util.metrics import DynamicCollectorRegistry
logger = logging.getLogger(__name__)
@@ -30,27 +32,62 @@ logger = logging.getLogger(__name__)
# Whether to track estimated memory usage of the LruCaches.
TRACK_MEMORY_USAGE = False
+# We track cache metrics in a special registry that lets us update the metrics
+# just before they are returned from the scrape endpoint.
+CACHE_METRIC_REGISTRY = DynamicCollectorRegistry()
caches_by_name: Dict[str, Sized] = {}
-collectors_by_name: Dict[str, "CacheMetric"] = {}
-cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
-cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
-cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name", "reason"])
-cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
-cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
+cache_size = Gauge(
+ "synapse_util_caches_cache_size", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+cache_hits = Gauge(
+ "synapse_util_caches_cache_hits", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+cache_evicted = Gauge(
+ "synapse_util_caches_cache_evicted_size",
+ "",
+ ["name", "reason"],
+ registry=CACHE_METRIC_REGISTRY,
+)
+cache_total = Gauge(
+ "synapse_util_caches_cache", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+cache_max_size = Gauge(
+ "synapse_util_caches_cache_max_size", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
cache_memory_usage = Gauge(
"synapse_util_caches_cache_size_bytes",
"Estimated memory usage of the caches",
["name"],
+ registry=CACHE_METRIC_REGISTRY,
)
-response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
-response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
+response_cache_size = Gauge(
+ "synapse_util_caches_response_cache_size",
+ "",
+ ["name"],
+ registry=CACHE_METRIC_REGISTRY,
+)
+response_cache_hits = Gauge(
+ "synapse_util_caches_response_cache_hits",
+ "",
+ ["name"],
+ registry=CACHE_METRIC_REGISTRY,
+)
response_cache_evicted = Gauge(
- "synapse_util_caches_response_cache:evicted_size", "", ["name", "reason"]
+ "synapse_util_caches_response_cache_evicted_size",
+ "",
+ ["name", "reason"],
+ registry=CACHE_METRIC_REGISTRY,
)
-response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+response_cache_total = Gauge(
+ "synapse_util_caches_response_cache", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+
+
+# Register our custom cache metrics registry with the global registry
+REGISTRY.register(CACHE_METRIC_REGISTRY)
class EvictionReason(Enum):
@@ -160,7 +197,7 @@ def register_cache(
resize_callback: A function which can be called to resize the cache.
Returns:
- CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+ an object which provides inc_{hits,misses,evictions} methods
"""
if resizable:
if not resize_callback:
@@ -170,7 +207,7 @@ def register_cache(
metric = CacheMetric(cache, cache_type, cache_name, collect_callback)
metric_name = "cache_%s_%s" % (cache_type, cache_name)
caches_by_name[cache_name] = cache
- collectors_by_name[metric_name] = metric
+ CACHE_METRIC_REGISTRY.register_hook(metric_name, metric.collect)
return metric
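
A small sketch of what registration now does, assuming a plain dict as the backing store (`my_example_cache` is an illustrative name): the returned `CacheMetric` accumulates counts eagerly, but the Prometheus gauges above are only refreshed by the registry's pre-collection hook at scrape time.

```python
from synapse.util.caches import register_cache

backing: dict = {}  # any Sized container works as the tracked cache
metric = register_cache("dict", "my_example_cache", backing, resizable=False)

# Counters accumulate immediately; the gauges are synced from them by
# `metric.collect` when CACHE_METRIC_REGISTRY is scraped.
metric.inc_hits()
metric.inc_misses()
```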
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1d6ec22191..bf7bd351e0 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -14,15 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import enum
import threading
from typing import (
Callable,
+ Collection,
+ Dict,
Generic,
- Iterable,
MutableMapping,
Optional,
+ Set,
Sized,
+ Tuple,
TypeVar,
Union,
cast,
@@ -31,7 +35,6 @@ from typing import (
from prometheus_client import Gauge
from twisted.internet import defer
-from twisted.python import failure
from twisted.python.failure import Failure
from synapse.util.async_helpers import ObservableDeferred
@@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache: Union[
- TreeCache, "MutableMapping[KT, CacheEntry]"
+ TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
] = cache_type()
def metrics_cb() -> None:
@@ -150,7 +153,7 @@ class DeferredCache(Generic[KT, VT]):
Args:
key:
callback: Gets called when the entry in the cache is invalidated
- update_metrics (bool): whether to update the cache hit rate metrics
+ update_metrics: whether to update the cache hit rate metrics
Returns:
A Deferred which completes with the result. Note that this may later fail
@@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
Raises:
KeyError if the key is not found in the cache
"""
- callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
- val.callbacks.update(callbacks)
+ val.add_invalidation_callback(key, callback)
if update_metrics:
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
- return val.deferred.observe()
+ return val.deferred(key)
+
+ callbacks = (callback,) if callback else ()
val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
else:
return defer.succeed(val2)
+ def get_bulk(
+ self,
+ keys: Collection[KT],
+ callback: Optional[Callable[[], None]] = None,
+ ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
+ """Bulk lookup of items in the cache.
+
+ Returns:
+ A 3-tuple of:
+ 1. a dict of key/value of items already cached;
+ 2. a deferred that resolves to a dict of key/value of items
+ we're already fetching; and
+ 3. a collection of keys that don't appear in the previous two.
+ """
+
+ # The cached results
+ cached = {}
+
+ # List of pending deferreds
+ pending = []
+
+ # Dict that gets filled out when the pending deferreds complete
+ pending_results = {}
+
+ # List of keys that aren't in either cache
+ missing = []
+
+ callbacks = (callback,) if callback else ()
+
+ for key in keys:
+ # Check if it's in the main cache.
+ immediate_value = self.cache.get(
+ key,
+ _Sentinel.sentinel,
+ callbacks=callbacks,
+ )
+ if immediate_value is not _Sentinel.sentinel:
+ cached[key] = immediate_value
+ continue
+
+ # Check if it's in the pending cache
+ pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+ if pending_value is not _Sentinel.sentinel:
+ pending_value.add_invalidation_callback(key, callback)
+
+ def completed_cb(value: VT, key: KT) -> VT:
+ pending_results[key] = value
+ return value
+
+ # Add a callback to fill out `pending_results` when that completes
+ d = pending_value.deferred(key).addCallback(completed_cb, key)
+ pending.append(d)
+ continue
+
+ # Not in either cache
+ missing.append(key)
+
+ # If we've got pending deferreds, squash them into a single one that
+ # returns `pending_results`.
+ pending_deferred = None
+ if pending:
+ pending_deferred = defer.gatherResults(
+ pending, consumeErrors=True
+ ).addCallback(lambda _: pending_results)
+
+ return (cached, pending_deferred, missing)
+
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
) -> Union[VT, T]:
@@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
- if not isinstance(value, defer.Deferred):
- raise TypeError("not a Deferred")
-
- callbacks = [callback] if callback else []
self.check_thread()
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry:
- existing_entry.invalidate()
+ self._pending_deferred_cache.pop(key, None)
# XXX: why don't we invalidate the entry in `self.cache` yet?
- # we can save a whole load of effort if the deferred is ready.
- if value.called:
- result = value.result
- if not isinstance(result, failure.Failure):
- self.cache.set(key, cast(VT, result), callbacks)
- return value
-
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
+ entry = CacheEntrySingle[KT, VT](value)
+ entry.add_invalidation_callback(key, callback)
+ self._pending_deferred_cache[key] = entry
+ deferred = entry.deferred(key).addCallbacks(
+ self._completed_callback,
+ self._error_callback,
+ callbackArgs=(entry, key),
+ errbackArgs=(entry, key),
+ )
- observable = ObservableDeferred(value, consumeErrors=True)
- observer = observable.observe()
- entry = CacheEntry(deferred=observable, callbacks=callbacks)
+ # we return a new Deferred which will be called before any subsequent observers.
+ return deferred
- self._pending_deferred_cache[key] = entry
+ def start_bulk_input(
+ self,
+ keys: Collection[KT],
+ callback: Optional[Callable[[], None]] = None,
+ ) -> "CacheMultipleEntries[KT, VT]":
+ """Bulk set API for use when fetching multiple keys at once from the DB.
- def compare_and_pop() -> bool:
- """Check if our entry is still the one in _pending_deferred_cache, and
- if so, pop it.
-
- Returns true if the entries matched.
- """
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- return True
-
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
- return False
-
- def cb(result: VT) -> None:
- if compare_and_pop():
- self.cache.set(key, result, entry.callbacks)
- else:
- # we're not going to put this entry into the cache, so need
- # to make sure that the invalidation callbacks are called.
- # That was probably done when _pending_deferred_cache was
- # updated, but it's possible that `set` was called without
- # `invalidate` being previously called, in which case it may
- # not have been. Either way, let's double-check now.
- entry.invalidate()
-
- def eb(_fail: Failure) -> None:
- compare_and_pop()
- entry.invalidate()
-
- # once the deferred completes, we can move the entry from the
- # _pending_deferred_cache to the real cache.
- #
- observer.addCallbacks(cb, eb)
+ Called *before* starting the fetch from the DB, and the caller *must*
+ call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
+ """
- # we return a new Deferred which will be called before any subsequent observers.
- return observable.observe()
+ entry = CacheMultipleEntries[KT, VT]()
+ entry.add_global_invalidation_callback(callback)
+
+ for key in keys:
+ self._pending_deferred_cache[key] = entry
+
+ return entry
+
+ def _completed_callback(
+ self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
+ ) -> VT:
+ """Called when a deferred is completed."""
+ # We check if the current entry matches the entry associated with the
+ # deferred. If they don't match then it got invalidated.
+ current_entry = self._pending_deferred_cache.pop(key, None)
+ if current_entry is not entry:
+ if current_entry:
+ self._pending_deferred_cache[key] = current_entry
+ return value
+
+ self.cache.set(key, value, entry.get_invalidation_callbacks(key))
+
+ return value
+
+ def _error_callback(
+ self,
+ failure: Failure,
+ entry: "CacheEntry[KT, VT]",
+ key: KT,
+ ) -> Failure:
+ """Called when a deferred errors."""
+
+ # We check if the current entry matches the entry associated with the
+ # deferred. If they don't match then it got invalidated.
+ current_entry = self._pending_deferred_cache.pop(key, None)
+ if current_entry is not entry:
+ if current_entry:
+ self._pending_deferred_cache[key] = current_entry
+ return failure
+
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
+
+ return failure
def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
) -> None:
- callbacks = [callback] if callback else []
+ callbacks = (callback,) if callback else ()
self.cache.set(key, value, callbacks=callbacks)
+ self._pending_deferred_cache.pop(key, None)
def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
@@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, which will (a) stop it being returned
- # for future queries and (b) stop it being persisted as a proper entry
+ # _pending_deferred_cache, which will (a) stop it being returned for
+ # future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
-
- # run the invalidation callbacks now, rather than waiting for the
- # deferred to resolve.
if entry:
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing.
- for entry in iterate_tree_cache_entry(entry):
- entry.invalidate()
+ for iter_entry in iterate_tree_cache_entry(entry):
+ for cb in iter_entry.get_invalidation_callbacks(key):
+ cb()
def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
- for entry in self._pending_deferred_cache.values():
- entry.invalidate()
+ for key, entry in self._pending_deferred_cache.items():
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
+
self._pending_deferred_cache.clear()
-class CacheEntry:
- __slots__ = ["deferred", "callbacks", "invalidated"]
+class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
+ """Abstract class for entries in `DeferredCache[KT, VT]`"""
- def __init__(
- self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
- ):
- self.deferred = deferred
- self.callbacks = set(callbacks)
- self.invalidated = False
-
- def invalidate(self) -> None:
- if not self.invalidated:
- self.invalidated = True
- for callback in self.callbacks:
- callback()
- self.callbacks.clear()
+ @abc.abstractmethod
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ """Get a deferred that a caller can wait on to get the value at the
+ given key"""
+ ...
+
+ @abc.abstractmethod
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ """Add an invalidation callback"""
+ ...
+
+ @abc.abstractmethod
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ """Get all invalidation callbacks"""
+ ...
+
+
+class CacheEntrySingle(CacheEntry[KT, VT]):
+ """An implementation of `CacheEntry` wrapping a deferred that results in a
+ single cache entry.
+ """
+
+ __slots__ = ["_deferred", "_callbacks"]
+
+ def __init__(self, deferred: "defer.Deferred[VT]") -> None:
+ self._deferred = ObservableDeferred(deferred, consumeErrors=True)
+ self._callbacks: Set[Callable[[], None]] = set()
+
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ return self._deferred.observe()
+
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ if callback is None:
+ return
+
+ self._callbacks.add(callback)
+
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ return self._callbacks
+
+
+class CacheMultipleEntries(CacheEntry[KT, VT]):
+ """Cache entry that is used for bulk lookups and insertions."""
+
+ __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
+
+ def __init__(self) -> None:
+ self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
+ self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
+ self._global_callbacks: Set[Callable[[], None]] = set()
+
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ if not self._deferred:
+ self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+ return self._deferred.observe().addCallback(lambda res: res.get(key))
+
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ if callback is None:
+ return
+
+ self._callbacks.setdefault(key, set()).add(callback)
+
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ return self._callbacks.get(key, set()) | self._global_callbacks
+
+ def add_global_invalidation_callback(
+ self, callback: Optional[Callable[[], None]]
+ ) -> None:
+ """Add a callback for when any keys get invalidated."""
+ if callback is None:
+ return
+
+ self._global_callbacks.add(callback)
+
+ def complete_bulk(
+ self,
+ cache: DeferredCache[KT, VT],
+ result: Dict[KT, VT],
+ ) -> None:
+ """Called when there is a result"""
+ for key, value in result.items():
+ cache._completed_callback(value, self, key)
+
+ if self._deferred:
+ self._deferred.callback(result)
+
+ def error_bulk(
+ self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
+ ) -> None:
+ """Called when bulk lookup failed."""
+ for key in keys:
+ cache._error_callback(failure, self, key)
+
+ if self._deferred:
+ self._deferred.errback(failure)
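
Putting the new pieces together, a read-through lookup might look like the sketch below; `fetch_from_db` is a stand-in for the real bulk query, and the fetched dict must cover every missing key (using `None` for absent rows) so no waiter is left hanging:

```python
from typing import Any, Collection, Dict

from twisted.python.failure import Failure

from synapse.util.caches.deferred_cache import DeferredCache


async def get_many(
    cache: DeferredCache, keys: Collection[Any], fetch_from_db
) -> Dict[Any, Any]:
    cached, pending, missing = cache.get_bulk(keys)
    results = dict(cached)

    if missing:
        # Register the entry *before* hitting the DB so concurrent callers
        # wait on it instead of issuing duplicate queries.
        entry = cache.start_bulk_input(missing)
        try:
            fetched = await fetch_from_db(missing)
        except Exception:
            entry.error_bulk(cache, missing, Failure())
            raise
        entry.complete_bulk(cache, fetched)
        results.update(fetched)

    if pending:
        results.update(await pending)

    return results
```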
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..72227359b9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import enum
import functools
import inspect
import logging
@@ -25,6 +24,7 @@ from typing import (
Generic,
Hashable,
Iterable,
+ List,
Mapping,
Optional,
Sequence,
@@ -52,7 +52,7 @@ CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
-class _CachedFunction(Generic[F]):
+class CachedFunction(Generic[F]):
invalidate: Any = None
invalidate_all: Any = None
prefill: Any = None
@@ -73,8 +73,10 @@ class _CacheDescriptorBase:
num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False,
+ name: Optional[str] = None,
):
self.orig = orig
+ self.name = name or orig.__name__
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
@@ -143,109 +145,6 @@ class _CacheDescriptorBase:
)
-class _LruCachedFunction(Generic[F]):
- cache: LruCache[CacheKey, Any]
- __call__: F
-
-
-def lru_cache(
- *, max_entries: int = 1000, cache_context: bool = False
-) -> Callable[[F], _LruCachedFunction[F]]:
- """A method decorator that applies a memoizing cache around the function.
-
- This is more-or-less a drop-in equivalent to functools.lru_cache, although note
- that the signature is slightly different.
-
- The main differences with functools.lru_cache are:
- (a) the size of the cache can be controlled via the cache_factor mechanism
- (b) the wrapped function can request a "cache_context" which provides a
- callback mechanism to indicate that the result is no longer valid
- (c) prometheus metrics are exposed automatically.
-
- The function should take zero or more arguments, which are used as the key for the
- cache. Single-argument functions use that argument as the cache key; otherwise the
- arguments are built into a tuple.
-
- Cached functions can be "chained" (i.e. a cached function can call other cached
- functions and get appropriately invalidated when they called caches are
- invalidated) by adding a special "cache_context" argument to the function
- and passing that as a kwarg to all caches called. For example:
-
- @lru_cache(cache_context=True)
- def foo(self, key, cache_context):
- r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
- r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
- return r1 + r2
-
- The wrapped function also has a 'cache' property which offers direct access to the
- underlying LruCache.
- """
-
- def func(orig: F) -> _LruCachedFunction[F]:
- desc = LruCacheDescriptor(
- orig,
- max_entries=max_entries,
- cache_context=cache_context,
- )
- return cast(_LruCachedFunction[F], desc)
-
- return func
-
-
-class LruCacheDescriptor(_CacheDescriptorBase):
- """Helper for @lru_cache"""
-
- class _Sentinel(enum.Enum):
- sentinel = object()
-
- def __init__(
- self,
- orig: Callable[..., Any],
- max_entries: int = 1000,
- cache_context: bool = False,
- ):
- super().__init__(
- orig, num_args=None, uncached_args=None, cache_context=cache_context
- )
- self.max_entries = max_entries
-
- def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
- cache: LruCache[CacheKey, Any] = LruCache(
- cache_name=self.orig.__name__,
- max_size=self.max_entries,
- )
-
- get_cache_key = self.cache_key_builder
- sentinel = LruCacheDescriptor._Sentinel.sentinel
-
- @functools.wraps(self.orig)
- def _wrapped(*args: Any, **kwargs: Any) -> Any:
- invalidate_callback = kwargs.pop("on_invalidate", None)
- callbacks = (invalidate_callback,) if invalidate_callback else ()
-
- cache_key = get_cache_key(args, kwargs)
-
- ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
- if ret != sentinel:
- return ret
-
- # Add our own `cache_context` to argument list if the wrapped function
- # has asked for one
- if self.add_cache_context:
- kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
-
- ret2 = self.orig(obj, *args, **kwargs)
- cache.set(cache_key, ret2, callbacks=callbacks)
-
- return ret2
-
- wrapped = cast(_CachedFunction, _wrapped)
- wrapped.cache = cache
- obj.__dict__[self.orig.__name__] = wrapped
-
- return wrapped
-
-
class DeferredCacheDescriptor(_CacheDescriptorBase):
"""A method decorator that applies a memoizing cache around the function.
@@ -301,12 +200,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
):
super().__init__(
orig,
num_args=num_args,
uncached_args=uncached_args,
cache_context=cache_context,
+ name=name,
)
if tree and self.num_args < 2:
@@ -321,7 +222,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
- name=self.orig.__name__,
+ name=self.name,
max_entries=self.max_entries,
tree=self.tree,
iterable=self.iterable,
@@ -358,7 +259,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(ret)
- wrapped = cast(_CachedFunction, _wrapped)
+ wrapped = cast(CachedFunction, _wrapped)
if self.num_args == 1:
assert not self.tree
@@ -372,7 +273,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
wrapped.cache = cache
wrapped.num_args = self.num_args
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -393,6 +294,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
+ name: Optional[str] = None,
):
"""
Args:
@@ -403,7 +305,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
- super().__init__(orig, num_args=num_args, uncached_args=None)
+ super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
self.list_name = list_name
@@ -425,6 +327,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
+ if num_args != self.num_args:
+ raise TypeError(
+ "Number of args (%s) does not match underlying cache_method_name=%s (%s)."
+ % (self.num_args, self.cached_method_name, num_args)
+ )
+
@functools.wraps(self.orig)
def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
# If we're passed a cache_context then we'll want to call its
@@ -435,16 +343,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
- results = {}
-
- def update_results_dict(res: Any, arg: Hashable) -> None:
- results[arg] = res
-
- # list of deferreds to wait for
- cached_defers = []
-
- missing = set()
-
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
@@ -452,6 +350,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key
+
else:
keylist = list(keyargs)
@@ -459,58 +360,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keylist[self.list_pos] = arg
return tuple(keylist)
- for arg in list_args:
- try:
- res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
- if not res.called:
- res.addCallback(update_results_dict, arg)
- cached_defers.append(res)
- else:
- results[arg] = res.result
- except KeyError:
- missing.add(arg)
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key[self.list_pos]
+
+ cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+ immediate_results, pending_deferred, missing = cache.get_bulk(
+ cache_keys, callback=invalidate_callback
+ )
+
+ results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+ cached_defers: List["defer.Deferred[Any]"] = []
+ if pending_deferred:
+
+ def update_results(r: Dict) -> None:
+ for k, v in r.items():
+ results[cache_key_to_arg(k)] = v
+
+ pending_deferred.addCallback(update_results)
+ cached_defers.append(pending_deferred)
if missing:
- # we need a deferred for each entry in the list,
- # which we put in the cache. Each deferred resolves with the
- # relevant result for that key.
- deferreds_map = {}
- for arg in missing:
- deferred: "defer.Deferred[Any]" = defer.Deferred()
- deferreds_map[arg] = deferred
- key = arg_to_cache_key(arg)
- cached_defers.append(
- cache.set(key, deferred, callback=invalidate_callback)
- )
+ cache_entry = cache.start_bulk_input(missing, invalidate_callback)
def complete_all(res: Dict[Hashable, Any]) -> None:
- # the wrapped function has completed. It returns a dict.
- # We can now update our own result map, and then resolve the
- # observable deferreds in the cache.
- for e, d1 in deferreds_map.items():
- val = res.get(e, None)
- # make sure we update the results map before running the
- # deferreds, because as soon as we run the last deferred, the
- # gatherResults() below will complete and return the result
- # dict to our caller.
- results[e] = val
- d1.callback(val)
+ missing_results = {}
+ for key in missing:
+ arg = cache_key_to_arg(key)
+ val = res.get(arg, None)
+
+ results[arg] = val
+ missing_results[key] = val
+
+ cache_entry.complete_bulk(cache, missing_results)
def errback_all(f: Failure) -> None:
- # the wrapped function has failed. Propagate the failure into
- # the cache, which will invalidate the entry, and cause the
- # relevant cached_deferreds to fail, which will propagate the
- # failure to our caller.
- for d1 in deferreds_map.values():
- d1.errback(f)
+ cache_entry.error_bulk(cache, missing, f)
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = missing
+ args_to_call[self.list_name] = {
+ cache_key_to_arg(key) for key in missing
+ }
# dispatch the call, and attach the two handlers
- defer.maybeDeferred(
+ missing_d = defer.maybeDeferred(
preserve_fn(self.orig), **args_to_call
).addCallbacks(complete_all, errback_all)
+ cached_defers.append(missing_d)
if cached_defers:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +421,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
else:
return defer.succeed(results)
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -577,7 +473,8 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
-) -> Callable[[F], _CachedFunction[F]]:
+ name: Optional[str] = None,
+) -> Callable[[F], CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
@@ -587,21 +484,26 @@ def cached(
cache_context=cache_context,
iterable=iterable,
prune_unread_entries=prune_unread_entries,
+ name=name,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def cachedList(
- *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
-) -> Callable[[F], _CachedFunction[F]]:
+ *,
+ cached_method_name: str,
+ list_name: str,
+ num_args: Optional[int] = None,
+ name: Optional[str] = None,
+) -> Callable[[F], CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments
is specified as a list that is iterated through to look up keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
the cache gets passed to the original function, which is expected to result
- in a map of key to value for each passed value. THe new results are stored in the
+ in a map of key to value for each passed value. The new results are stored in the
original cache. Note that any missing values are cached as None.
Args:
@@ -628,9 +530,10 @@ def cachedList(
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
+ name=name,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def _get_cache_key_builder(
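
For context, this is roughly how the two decorators compose after the rename and the new `name`/arity checks; the store class and its methods here are illustrative, not from the codebase:

```python
from typing import Collection, Dict, Optional

from synapse.util.caches.descriptors import cached, cachedList


class ExampleStore:
    @cached(name="get_thing")
    async def get_thing(self, thing_id: str) -> Optional[str]:
        ...  # single-key fetch

    # Must take the same non-list args as `get_thing` (here: none),
    # otherwise the new check in `__get__` raises TypeError.
    @cachedList(cached_method_name="get_thing", list_name="thing_ids")
    async def get_things(
        self, thing_ids: Collection[str]
    ) -> Dict[str, Optional[str]]:
        # Called only with ids not already cached; must return a map of
        # id -> value. Missing ids are cached as None.
        ...
```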
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index fa91479c97..5eaf70c7ab 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -169,10 +169,11 @@ class DictionaryCache(Generic[KT, DKT, DV]):
if it is in the cache.
Returns:
- DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry`
- will contain include the keys that are in the cache. If None then
- will either return the full dict if in the cache, or the empty
- dict (with `full` set to False) if it isn't.
+ If `dict_keys` is not None then `DictionaryEntry` will contain
+ the keys that are in the cache.
+
+ If None then this will either return the full dict if it is in the
+ cache, or the empty dict (with `full` set to False) if it isn't.
"""
if dict_keys is None:
# The caller wants the full set of dictionary keys for this cache key
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index c6a5d0dfc0..01ad02af67 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -207,7 +207,7 @@ class ExpiringCache(Generic[KT, VT]):
items from the cache.
Returns:
- bool: Whether the cache changed size or not.
+ Whether the cache changed size or not.
"""
new_size = int(self._original_max_size * factor)
if new_size != self._max_size:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index b3bdedb04c..dcf0eac3bf 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -389,11 +389,11 @@ class LruCache(Generic[KT, VT]):
cache_name: The name of this cache, for the prometheus metrics. If unset,
no metrics will be reported on this cache.
- cache_type (type):
+ cache_type:
type of underlying cache to be used. Typically one of dict
or TreeCache.
- size_callback (func(V) -> int | None):
+ size_callback:
metrics_collection_callback:
metrics collection callback. This is called early in the metrics
@@ -403,7 +403,7 @@ class LruCache(Generic[KT, VT]):
Ignored if cache_name is None.
- apply_cache_factor_from_config (bool): If true, `max_size` will be
+ apply_cache_factor_from_config: If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
clock:
@@ -796,7 +796,7 @@ class LruCache(Generic[KT, VT]):
items from the cache.
Returns:
- bool: Whether the cache changed size or not.
+ Whether the cache changed size or not.
"""
if not self.apply_cache_factor_from_config:
return False
@@ -834,9 +834,26 @@ class AsyncLruCache(Generic[KT, VT]):
) -> Optional[VT]:
return self._lru_cache.get(key, update_metrics=update_metrics)
+ async def get_external(
+ self,
+ key: KT,
+ default: Optional[T] = None,
+ update_metrics: bool = True,
+ ) -> Optional[VT]:
+ # This method should fetch from any configured external cache; here it is a no-op.
+ return None
+
+ def get_local(
+ self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+ ) -> Optional[VT]:
+ return self._lru_cache.get(key, update_metrics=update_metrics)
+
async def set(self, key: KT, value: VT) -> None:
self._lru_cache.set(key, value)
+ def set_local(self, key: KT, value: VT) -> None:
+ self._lru_cache.set(key, value)
+
async def invalidate(self, key: KT) -> None:
# This method should invalidate any external cache and then invalidate the LruCache.
return self._lru_cache.invalidate(key)
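
The new `get_local`/`get_external`/`set_local` split suggests a tiered lookup like the following sketch; `fetch` is a stand-in for the real loader, and treating `None` as a miss is a simplifying assumption:

```python
from typing import Awaitable, Callable, Optional

from synapse.util.caches.lrucache import AsyncLruCache


async def lookup(
    cache: AsyncLruCache, key: str, fetch: Callable[[str], Awaitable[str]]
) -> Optional[str]:
    # 1. Synchronous check of the in-memory cache.
    value = cache.get_local(key)
    if value is not None:
        return value

    # 2. The external tier (a no-op in the base class; a subclass might
    # consult Redis here).
    value = await cache.get_external(key)
    if value is not None:
        cache.set_local(key, value)  # warm only the local tier
        return value

    # 3. Fall back to the real fetch and populate both tiers.
    value = await fetch(key)
    await cache.set(key, value)
    return value
```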
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 330709b8b7..666f4b6895 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -72,7 +72,7 @@ class StreamChangeCache:
items from the cache.
Returns:
- bool: Whether the cache changed size or not.
+ Whether the cache changed size or not.
"""
new_size = math.floor(self._original_max_size * factor)
if new_size != self._max_size:
@@ -188,14 +188,8 @@ class StreamChangeCache:
self._entity_to_key[entity] = stream_pos
self._evict()
- # if the cache is too big, remove entries
- while len(self._cache) > self._max_size:
- k, r = self._cache.popitem(0)
- self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
- for entity in r:
- del self._entity_to_key[entity]
-
def _evict(self) -> None:
+ # if the cache is too big, remove entries
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
@@ -203,7 +197,6 @@ class StreamChangeCache:
self._entity_to_key.pop(entity, None)
def get_max_pos_of_last_change(self, entity: EntityType) -> int:
-
"""Returns an upper bound of the stream id of the last change to an
entity.
"""
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index c1b8ec0c73..fec31da2b6 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -135,6 +135,9 @@ class TreeCache:
def values(self):
return iterate_tree_cache_entry(self.root)
+ def items(self):
+ return iterate_tree_cache_items((), self.root)
+
def __len__(self) -> int:
return self.size
diff --git a/synapse/util/cancellation.py b/synapse/util/cancellation.py
new file mode 100644
index 0000000000..472d2e3aeb
--- /dev/null
+++ b/synapse/util/cancellation.py
@@ -0,0 +1,56 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, TypeVar
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def cancellable(function: F) -> F:
+ """Marks a function as cancellable.
+
+ Servlet methods with this decorator will be cancelled if the client disconnects before we
+ finish processing the request.
+
+ Although this annotation is particularly useful for servlet methods, it's also
+ useful for intermediate functions, where it documents the fact that the function has
+ been audited for cancellation safety and needs to preserve that.
+ This then simplifies auditing new functions that call those same intermediate
+ functions.
+
+ During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
+ the method. The `cancel()` call will propagate down to the `Deferred` that is
+ currently being waited on. That `Deferred` will raise a `CancelledError`, which will
+ propagate up, as per normal exception handling.
+
+ Before applying this decorator to a new function, you MUST recursively check
+ that all `await`s in the function are on `async` functions or `Deferred`s that
+ handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
+ premature logging context closure, to stuck requests, to database corruption.
+
+ See the documentation page on Cancellation for more information.
+
+ Usage:
+ class SomeServlet(RestServlet):
+ @cancellable
+ async def on_GET(self, request: SynapseRequest) -> ...:
+ ...
+ """
+
+ function.cancellable = True # type: ignore[attr-defined]
+ return function
+
+
+def is_function_cancellable(function: Callable[..., Any]) -> bool:
+ """Checks whether a servlet method has the `@cancellable` flag."""
+ return getattr(function, "cancellable", False)
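
The flag is just an attribute on the function, so dispatch code can check it cheaply; the servlet below is a toy example, not Synapse's real dispatcher:

```python
from synapse.util.cancellation import cancellable, is_function_cancellable


class ToyServlet:
    @cancellable
    async def on_GET(self) -> str:
        return "ok"

    async def on_POST(self) -> str:
        return "ok"


servlet = ToyServlet()
assert is_function_cancellable(servlet.on_GET)       # attribute set by decorator
assert not is_function_cancellable(servlet.on_POST)  # default: not cancellable
```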
diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index 66f1da7502..3b1e205700 100644
--- a/synapse/util/check_dependencies.py
+++ b/synapse/util/check_dependencies.py
@@ -66,6 +66,21 @@ def _is_dev_dependency(req: Requirement) -> bool:
)
+def _should_ignore_runtime_requirement(req: Requirement) -> bool:
+ # This is a build-time dependency. Irritatingly, `poetry build` ignores the
+ # requirements listed in the [build-system] section of pyproject.toml, so in order
+ # to support `poetry install --no-dev` we have to mark it as a runtime dependency.
+ # See discussion on https://github.com/python-poetry/poetry/issues/6154 (it sounds
+ # like the poetry authors don't consider this a bug?)
+ #
+ # In any case, work around this by ignoring setuptools_rust here. (It might be
+ # slightly cleaner to put `setuptools_rust` in a `build` extra or similar, but for
+ # now let's do something quick and dirty.)
+ if req.name == "setuptools_rust":
+ return True
+ return False
+
+
class Dependency(NamedTuple):
requirement: Requirement
must_be_installed: bool
@@ -77,7 +92,7 @@ def _generic_dependencies() -> Iterable[Dependency]:
assert requirements is not None
for raw_requirement in requirements:
req = Requirement(raw_requirement)
- if _is_dev_dependency(req):
+ if _is_dev_dependency(req) or _should_ignore_runtime_requirement(req):
continue
# https://packaging.pypa.io/en/latest/markers.html#usage notes that
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index df77edcce2..5df03d3ddc 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -24,7 +24,7 @@ from typing_extensions import Literal
from synapse.util import Clock, stringutils
-MacaroonType = Literal["access", "delete_pusher", "session", "login"]
+MacaroonType = Literal["access", "delete_pusher", "session"]
def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
@@ -111,19 +111,6 @@ class OidcSessionData:
"""The session ID of the ongoing UI Auth ("" if this is a login)"""
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class LoginTokenAttributes:
- """Data we store in a short-term login token"""
-
- user_id: str
-
- auth_provider_id: str
- """The SSO Identity Provider that the user authenticated with, to get this token."""
-
- auth_provider_session_id: Optional[str]
- """The session ID advertised by the SSO Identity Provider."""
-
-
class MacaroonGenerator:
def __init__(self, clock: Clock, location: str, secret_key: bytes):
self._clock = clock
@@ -165,35 +152,6 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat(f"pushkey = {pushkey}")
return macaroon.serialize()
- def generate_short_term_login_token(
- self,
- user_id: str,
- auth_provider_id: str,
- auth_provider_session_id: Optional[str] = None,
- duration_in_ms: int = (2 * 60 * 1000),
- ) -> str:
- """Generate a short-term login token used during SSO logins
-
- Args:
- user_id: The user for which the token is valid.
- auth_provider_id: The SSO IdP the user used.
- auth_provider_session_id: The session ID got during login from the SSO IdP.
-
- Returns:
- A signed token valid for using as a ``m.login.token`` token.
- """
- now = self._clock.time_msec()
- expiry = now + duration_in_ms
- macaroon = self._generate_base_macaroon("login")
- macaroon.add_first_party_caveat(f"user_id = {user_id}")
- macaroon.add_first_party_caveat(f"time < {expiry}")
- macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}")
- if auth_provider_session_id is not None:
- macaroon.add_first_party_caveat(
- f"auth_provider_session_id = {auth_provider_session_id}"
- )
- return macaroon.serialize()
-
def generate_oidc_session_token(
self,
state: str,
@@ -233,49 +191,6 @@ class MacaroonGenerator:
return macaroon.serialize()
- def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
- """Verify a short-term-login macaroon
-
- Checks that the given token is a valid, unexpired short-term-login token
- minted by this server.
-
- Args:
- token: The login token to verify.
-
- Returns:
- A set of attributes carried by this token, including the
- ``user_id`` and informations about the SSO IDP used during that
- login.
-
- Raises:
- MacaroonVerificationFailedException if the verification failed
- """
- macaroon = pymacaroons.Macaroon.deserialize(token)
-
- v = self._base_verifier("login")
- v.satisfy_general(lambda c: c.startswith("user_id = "))
- v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
- v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
- satisfy_expiry(v, self._clock.time_msec)
- v.verify(macaroon, self._secret_key)
-
- user_id = get_value_from_macaroon(macaroon, "user_id")
- auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
-
- auth_provider_session_id: Optional[str] = None
- try:
- auth_provider_session_id = get_value_from_macaroon(
- macaroon, "auth_provider_session_id"
- )
- except MacaroonVerificationFailedException:
- pass
-
- return LoginTokenAttributes(
- user_id=user_id,
- auth_provider_id=auth_provider_id,
- auth_provider_session_id=auth_provider_session_id,
- )
-
def verify_guest_token(self, token: str) -> str:
"""Verify a guest access token macaroon
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index bc3b4938ea..165480bdbe 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,9 +15,9 @@
import logging
from functools import wraps
from types import TracebackType
-from typing import Awaitable, Callable, Optional, Type, TypeVar
+from typing import Awaitable, Callable, Dict, Generator, Optional, Type, TypeVar
-from prometheus_client import Counter
+from prometheus_client import CollectorRegistry, Counter, Metric
from typing_extensions import Concatenate, ParamSpec, Protocol
from synapse.logging.context import (
@@ -208,3 +208,33 @@ class Measure:
metrics.real_time_sum += duration
# TODO: Add other in flight metrics.
+
+
+class DynamicCollectorRegistry(CollectorRegistry):
+ """
+ Custom Prometheus Collector registry that calls a hook first, allowing you
+ to update metrics on-demand.
+
+ Don't forget to register this registry with the main registry!
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self._pre_update_hooks: Dict[str, Callable[[], None]] = {}
+
+ def collect(self) -> Generator[Metric, None, None]:
+ """
+ Collects metrics, calling pre-update hooks first.
+ """
+
+ for pre_update_hook in self._pre_update_hooks.values():
+ pre_update_hook()
+
+ yield from super().collect()
+
+ def register_hook(self, metric_name: str, hook: Callable[[], None]) -> None:
+ """
+ Registers a hook that is called before metric collection.
+ """
+
+ self._pre_update_hooks[metric_name] = hook
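
A minimal sketch of the lazy-update pattern this enables, mirroring how the cache metrics use it (the gauge and queue here are illustrative):

```python
from prometheus_client import REGISTRY, Gauge

from synapse.util.metrics import DynamicCollectorRegistry

registry = DynamicCollectorRegistry()
queue_depth = Gauge("example_queue_depth", "", registry=registry)
work_queue: list = []

# The hook runs just before each scrape, so the gauge is kept up to date
# on demand rather than on every queue mutation.
registry.register_hook(
    "example_queue_depth", lambda: queue_depth.set(len(work_queue))
)

# As the docstring says: don't forget to expose it via the main registry.
REGISTRY.register(registry)
```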
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 6394cc39ac..2aceb1a47f 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -15,8 +15,23 @@
import collections
import contextlib
import logging
+import threading
import typing
-from typing import Any, DefaultDict, Iterator, List, Set
+from typing import (
+ Any,
+ Callable,
+ DefaultDict,
+ Dict,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+)
+
+from prometheus_client.core import Counter
+from typing_extensions import ContextManager
from twisted.internet import defer
@@ -27,6 +42,8 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics import Histogram, LaterGauge
from synapse.util import Clock
if typing.TYPE_CHECKING:
@@ -35,15 +52,127 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Track how much the ratelimiter is affecting requests
+rate_limit_sleep_counter = Counter(
+ "synapse_rate_limit_sleep",
+ "Number of requests slept by the rate limiter",
+ ["rate_limiter_name"],
+)
+rate_limit_reject_counter = Counter(
+ "synapse_rate_limit_reject",
+ "Number of requests rejected by the rate limiter",
+ ["rate_limiter_name"],
+)
+queue_wait_timer = Histogram(
+ "synapse_rate_limit_queue_wait_time_seconds",
+ "Amount of time spent waiting for the rate limiter to let our request through.",
+ ["rate_limiter_name"],
+ buckets=(
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.1,
+ 0.25,
+ 0.5,
+ 0.75,
+ 1.0,
+ 2.5,
+ 5.0,
+ 10.0,
+ 20.0,
+ "+Inf",
+ ),
+)
+
+
+_rate_limiter_instances: Set["FederationRateLimiter"] = set()
+# Protects the _rate_limiter_instances set from concurrent access
+_rate_limiter_instances_lock = threading.Lock()
+
+
+def _get_counts_from_rate_limiter_instance(
+ count_func: Callable[["FederationRateLimiter"], int]
+) -> Mapping[Tuple[str, ...], int]:
+ """Returns a count of something (slept/rejected hosts) by (metrics_name)"""
+ # Cast to a list to prevent it changing while the Prometheus
+ # thread is collecting metrics
+ with _rate_limiter_instances_lock:
+ rate_limiter_instances = list(_rate_limiter_instances)
+
+ # Map from (metrics_name,) -> int, the number of something like slept hosts
+ # or rejected hosts. The key type is Tuple[str], but we leave the length
+ # unspecified for compatibility with LaterGauge's annotations.
+ counts: Dict[Tuple[str, ...], int] = {}
+ for rate_limiter_instance in rate_limiter_instances:
+ # Only track metrics if they provided a `metrics_name` to
+ # differentiate this instance of the rate limiter.
+ if rate_limiter_instance.metrics_name:
+ key = (rate_limiter_instance.metrics_name,)
+ counts[key] = count_func(rate_limiter_instance)
+
+ return counts
+
+
+# We track the number of affected hosts per time-period so we can
+# differentiate one really noisy homeserver from a general
+# ratelimit tuning problem across the federation.
+LaterGauge(
+ "synapse_rate_limit_sleep_affected_hosts",
+ "Number of hosts that had requests put to sleep",
+ ["rate_limiter_name"],
+ lambda: _get_counts_from_rate_limiter_instance(
+ lambda rate_limiter_instance: sum(
+ ratelimiter.should_sleep()
+ for ratelimiter in rate_limiter_instance.ratelimiters.values()
+ )
+ ),
+)
+LaterGauge(
+ "synapse_rate_limit_reject_affected_hosts",
+ "Number of hosts that had requests rejected",
+ ["rate_limiter_name"],
+ lambda: _get_counts_from_rate_limiter_instance(
+ lambda rate_limiter_instance: sum(
+ ratelimiter.should_reject()
+ for ratelimiter in rate_limiter_instance.ratelimiters.values()
+ )
+ ),
+)
+
+
class FederationRateLimiter:
- def __init__(self, clock: Clock, config: FederationRatelimitSettings):
+ """Used to rate limit request per-host."""
+
+ def __init__(
+ self,
+ clock: Clock,
+ config: FederationRatelimitSettings,
+ metrics_name: Optional[str] = None,
+ ):
+ """
+ Args:
+ clock
+ config
+ metrics_name: The name of the rate limiter so we can differentiate it
+ from the rest in the metrics. If `None`, we don't track metrics
+ for this rate limiter.
+
+ """
+ self.metrics_name = metrics_name
+
def new_limiter() -> "_PerHostRatelimiter":
- return _PerHostRatelimiter(clock=clock, config=config)
+ return _PerHostRatelimiter(
+ clock=clock, config=config, metrics_name=metrics_name
+ )
self.ratelimiters: DefaultDict[
str, "_PerHostRatelimiter"
] = collections.defaultdict(new_limiter)
+ with _rate_limiter_instances_lock:
+ _rate_limiter_instances.add(self)
+
def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
"""Used to ratelimit an incoming request from a given host
@@ -54,22 +183,32 @@ class FederationRateLimiter:
# Handle request ...
Args:
- host (str): Origin of incoming request.
+ host: Origin of incoming request.
Returns:
context manager which returns a deferred.
"""
- return self.ratelimiters[host].ratelimit()
+ return self.ratelimiters[host].ratelimit(host)
class _PerHostRatelimiter:
- def __init__(self, clock: Clock, config: FederationRatelimitSettings):
+ def __init__(
+ self,
+ clock: Clock,
+ config: FederationRatelimitSettings,
+ metrics_name: Optional[str] = None,
+ ):
"""
Args:
clock
config
+ metrics_name: The name of the rate limiter so we can differentiate it
+ from the rest in the metrics. If `None`, we don't track metrics
+ for this rate limiter.
"""
self.clock = clock
+ self.metrics_name = metrics_name
self.window_size = config.window_size
self.sleep_limit = config.sleep_limit
@@ -94,19 +233,45 @@ class _PerHostRatelimiter:
self.request_times: List[int] = []
@contextlib.contextmanager
- def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
+ def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
# `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value
# to be returned by manager.
# Exceptions will be reraised at the yield.
+ self.host = host
+
request_id = object()
- ret = self._on_enter(request_id)
+ # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+ # type-checking, but we'd need Twisted >= 21.2.
+ ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
try:
yield ret
finally:
self._on_exit(request_id)
+ def should_reject(self) -> bool:
+ """
+ Whether to reject the request if we already have too many queued up
+ (either sleeping or in the ready queue).
+ """
+ queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
+ return queue_size > self.reject_limit
+
+ def should_sleep(self) -> bool:
+ """
+ Whether to sleep the request if we already have too many requests coming
+ through within the window.
+ """
+ return len(self.request_times) > self.sleep_limit
+
+ async def _on_enter_with_tracing(self, request_id: object) -> None:
+ maybe_metrics_cm: ContextManager = contextlib.nullcontext()
+ if self.metrics_name:
+ maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time()
+ with start_active_span("ratelimit wait"), maybe_metrics_cm:
+ await self._on_enter(request_id)
+
def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
time_now = self.clock.time_msec()
@@ -117,8 +282,10 @@ class _PerHostRatelimiter:
# reject the request if we already have too many queued up (either
# sleeping or in the ready queue).
- queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
- if queue_size > self.reject_limit:
+ if self.should_reject():
+ logger.debug("Ratelimiter(%s): rejecting request", self.host)
+ if self.metrics_name:
+ rate_limit_reject_counter.labels(self.metrics_name).inc()
raise LimitExceededError(
retry_after_ms=int(self.window_size / self.sleep_limit)
)
@@ -130,7 +297,8 @@ class _PerHostRatelimiter:
queue_defer: defer.Deferred[None] = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
- "Ratelimiter: queueing request (queue now %i items)",
+ "Ratelimiter(%s): queueing request (queue now %i items)",
+ self.host,
len(self.ready_request_queue),
)
@@ -139,19 +307,29 @@ class _PerHostRatelimiter:
return defer.succeed(None)
logger.debug(
- "Ratelimit [%s]: len(self.request_times)=%d",
+ "Ratelimit(%s) [%s]: len(self.request_times)=%d",
+ self.host,
id(request_id),
len(self.request_times),
)
- if len(self.request_times) > self.sleep_limit:
- logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
+ if self.should_sleep():
+ logger.debug(
+ "Ratelimiter(%s) [%s]: sleeping request for %f sec",
+ self.host,
+ id(request_id),
+ self.sleep_sec,
+ )
+ if self.metrics_name:
+ rate_limit_sleep_counter.labels(self.metrics_name).inc()
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
def on_wait_finished(_: Any) -> "defer.Deferred[None]":
- logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
+ logger.debug(
+ "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
+ )
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
@@ -161,7 +339,9 @@ class _PerHostRatelimiter:
ret_defer = queue_request()
def on_start(r: object) -> object:
- logger.debug("Ratelimit [%s]: Processing req", id(request_id))
+ logger.debug(
+ "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
+ )
self.current_processing.add(request_id)
return r
@@ -183,7 +363,7 @@ class _PerHostRatelimiter:
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id: object) -> None:
- logger.debug("Ratelimit [%s]: Processed req", id(request_id))
+ logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.
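
For reference, the per-host entry point now threads through the host name and an optional metrics name; a hedged usage sketch, assuming `clock` and `config` come from the homeserver plumbing and that the settings class lives where the type hints above suggest:

```python
from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.util import Clock
from synapse.util.ratelimitutils import FederationRateLimiter


def make_limiter(
    clock: Clock, config: FederationRatelimitSettings
) -> FederationRateLimiter:
    # `metrics_name` feeds the new sleep/reject counters and LaterGauges.
    return FederationRateLimiter(clock, config, metrics_name="example")


async def handle_request(limiter: FederationRateLimiter, origin: str) -> None:
    with limiter.ratelimit(origin) as wait_deferred:
        # Sleeps, or raises LimitExceededError, according to the config.
        await wait_deferred
        # ... handle the request ...
```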
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index d0a69ff843..dcc037b982 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -51,7 +51,7 @@ class NotRetryingDestination(Exception):
destination: the domain in question
"""
- msg = "Not retrying server %s." % (destination,)
+ msg = f"Not retrying server {destination} because we tried it recently retry_last_ts={retry_last_ts} and we won't check for another retry_interval={retry_interval}ms."
super().__init__(msg)
self.retry_last_ts = retry_last_ts
diff --git a/synapse/util/rust.py b/synapse/util/rust.py
new file mode 100644
index 0000000000..30ecb9ffd9
--- /dev/null
+++ b/synapse/util/rust.py
@@ -0,0 +1,84 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+from hashlib import blake2b
+
+import synapse
+from synapse.synapse_rust import get_rust_file_digest
+
+
+def check_rust_lib_up_to_date() -> None:
+ """For editable installs check if the rust library is outdated and needs to
+ be rebuilt.
+ """
+
+ if not _dist_is_editable():
+ return
+
+ synapse_dir = os.path.dirname(synapse.__file__)
+ synapse_root = os.path.abspath(os.path.join(synapse_dir, ".."))
+
+ # Double check we've not gone into site-packages...
+ if os.path.basename(synapse_root) == "site-packages":
+ return
+
+ # ... and it looks like the root of a python project.
+ if not os.path.exists(os.path.join(synapse_root, "pyproject.toml")):
+ return
+
+ # Get the hash of all Rust source files
+ hash = _hash_rust_files_in_directory(os.path.join(synapse_root, "rust", "src"))
+
+ if hash != get_rust_file_digest():
+ raise Exception("Rust module outdated. Please rebuild using `poetry install`")
+
+
+def _hash_rust_files_in_directory(directory: str) -> str:
+ """Get the hash of all files in a directory (recursively)"""
+
+ directory = os.path.abspath(directory)
+
+ paths = []
+
+ dirs = [directory]
+ while dirs:
+ dir = dirs.pop()
+ with os.scandir(dir) as d:
+ for entry in d:
+ if entry.is_dir():
+ dirs.append(entry.path)
+ else:
+ paths.append(entry.path)
+
+ # We sort to make sure that we get a consistent and well-defined ordering.
+ paths.sort()
+
+ hasher = blake2b()
+
+ for path in paths:
+ with open(os.path.join(directory, path), "rb") as f:
+ hasher.update(f.read())
+
+ return hasher.hexdigest()
+
+
+def _dist_is_editable() -> bool:
+ """Is distribution an editable install?"""
+ for path_item in sys.path:
+ egg_link = os.path.join(path_item, "matrix-synapse.egg-link")
+ if os.path.isfile(egg_link):
+ return True
+ return False
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 27a363d7e5..4961fe9313 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -86,7 +86,7 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
ValueError if the server name could not be parsed.
"""
try:
- if server_name[-1] == "]":
+ if server_name and server_name[-1] == "]":
# ipv6 literal, hopefully
return server_name, None
@@ -123,7 +123,7 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
- if host[0] == "[":
+ if host and host[0] == "[":
if host[-1] != "]":
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 1e9c2faa64..54bc7589fd 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -48,7 +48,7 @@ async def check_3pid_allowed(
registration: whether we want to bind the 3PID as part of registering a new user.
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ whether the 3PID medium/address is allowed to be added to this HS
"""
if not await hs.get_password_auth_provider().is_3pid_allowed(
medium, address, registration
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 177e198e7e..b1ec7f4bd8 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -90,10 +90,10 @@ class WheelTimer(Generic[T]):
"""Fetch any objects that have timed out
Args:
- now (ms): Current time in msec
+ now: Current time in msec
Returns:
- list: List of objects that have timed out
+ List of objects that have timed out
"""
now_key = int(now / self.bucket_size)
|