summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py18
-rw-r--r--synapse/util/async_helpers.py188
-rw-r--r--synapse/util/caches/__init__.py4
-rw-r--r--synapse/util/caches/descriptors.py102
-rw-r--r--synapse/util/caches/dictionary_cache.py4
-rw-r--r--synapse/util/caches/expiringcache.py4
-rw-r--r--synapse/util/caches/lrucache.py4
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--synapse/util/caches/treecache.py4
-rw-r--r--synapse/util/caches/ttlcache.py4
-rw-r--r--synapse/util/daemonize.py137
-rw-r--r--synapse/util/distributor.py54
-rw-r--r--synapse/util/file_consumer.py2
-rw-r--r--synapse/util/frozenutils.py10
-rw-r--r--synapse/util/jsonobject.py2
-rw-r--r--synapse/util/manhole.py2
-rw-r--r--synapse/util/metrics.py79
-rw-r--r--synapse/util/patch_inline_callbacks.py2
-rw-r--r--synapse/util/ratelimitutils.py4
-rw-r--r--synapse/util/retryutils.py20
-rw-r--r--synapse/util/stringutils.py4
-rw-r--r--synapse/util/wheel_timer.py4
22 files changed, 392 insertions, 262 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py

index c63256d3bd..d55b93d763 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py
@@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import re @@ -25,14 +26,27 @@ from synapse.logging import context logger = logging.getLogger(__name__) +def _reject_invalid_json(val): + """Do not allow Infinity, -Infinity, or NaN values in JSON.""" + raise ValueError("Invalid JSON value: '%s'" % val) + + +# Create a custom encoder to reduce the whitespace produced by JSON encoding and +# ensure that valid JSON is produced. +json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":")) + +# Create a custom decoder to reject Python extensions to JSON. +json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) + + def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) return failure.value.subFailure -@attr.s -class Clock(object): +@attr.s(slots=True) +class Clock: """ A Clock wraps a Twisted reactor and provides utilities on top of it. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f562770922..382f0cf3f0 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -17,12 +17,25 @@ import collections import logging from contextlib import contextmanager -from typing import Dict, Sequence, Set, Union +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + TypeVar, + Union, +) import attr +from typing_extensions import ContextManager from twisted.internet import defer from twisted.internet.defer import CancelledError +from twisted.internet.interfaces import IReactorTime from twisted.python import failure from synapse.logging.context import ( @@ -35,7 +48,7 @@ from synapse.util import Clock, unwrapFirstError logger = logging.getLogger(__name__) -class ObservableDeferred(object): +class ObservableDeferred: """Wraps a deferred object so that we can add observer deferreds. These observer deferreds do not affect the callback chain of the original deferred. @@ -53,7 +66,7 @@ class ObservableDeferred(object): __slots__ = ["_deferred", "_observers", "_result"] - def __init__(self, deferred, consumeErrors=False): + def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) object.__setattr__(self, "_observers", set()) @@ -110,25 +123,25 @@ class ObservableDeferred(object): success, res = self._result return defer.succeed(res) if success else defer.fail(res) - def observers(self): + def observers(self) -> List[defer.Deferred]: return self._observers - def has_called(self): + def has_called(self) -> bool: return self._result is not None - def has_succeeded(self): + def has_succeeded(self) -> bool: return self._result is not None and self._result[0] is True - def get_result(self): + def get_result(self) -> Any: return self._result[1] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self._deferred, name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: setattr(self._deferred, name, value) - def __repr__(self): + def __repr__(self) -> str: return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % ( id(self), self._result, @@ -136,18 +149,20 @@ class ObservableDeferred(object): ) -def concurrently_execute(func, args, limit): - """Executes the function with each argument conncurrently while limiting +def concurrently_execute( + func: Callable, args: Iterable[Any], limit: int +) -> defer.Deferred: + """Executes the function with each argument concurrently while limiting the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred or coroutine. - args (Iterable): List of arguments to pass to func, each invocation of func + func: Function to execute, should return a deferred or coroutine. + args: List of arguments to pass to func, each invocation of func gets a single argument. - limit (int): Maximum number of conccurent executions. + limit: Maximum number of conccurent executions. Returns: - deferred: Resolved when all function invocations have finished. + Deferred[list]: Resolved when all function invocations have finished. """ it = iter(args) @@ -166,14 +181,17 @@ def concurrently_execute(func, args, limit): ).addErrback(unwrapFirstError) -def yieldable_gather_results(func, iter, *args, **kwargs): +def yieldable_gather_results( + func: Callable, iter: Iterable, *args: Any, **kwargs: Any +) -> defer.Deferred: """Executes the function with each argument concurrently. Args: - func (func): Function to execute that returns a Deferred - iter (iter): An iterable that yields items that get passed as the first + func: Function to execute that returns a Deferred + iter: An iterable that yields items that get passed as the first argument to the function *args: Arguments to be passed to each call to func + **kwargs: Keyword arguments to be passed to each call to func Returns Deferred[list]: Resolved when all functions have been invoked, or errors if @@ -187,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs): ).addErrback(unwrapFirstError) -class Linearizer(object): +@attr.s(slots=True) +class _LinearizerEntry: + # The number of things executing. + count = attr.ib(type=int) + # Deferreds for the things blocked from executing. + deferreds = attr.ib(type=collections.OrderedDict) + + +class Linearizer: """Limits concurrent access to resources based on a key. Useful to ensure only a few things happen at a time on a given resource. Example: - with (yield limiter.queue("test_key")): + with await limiter.queue("test_key"): # do some work. """ - def __init__(self, name=None, max_count=1, clock=None): + def __init__( + self, + name: Optional[str] = None, + max_count: int = 1, + clock: Optional[Clock] = None, + ): """ Args: - max_count(int): The maximum number of concurrent accesses + max_count: The maximum number of concurrent accesses """ if name is None: - self.name = id(self) + self.name = id(self) # type: Union[str, int] else: self.name = name @@ -215,15 +246,10 @@ class Linearizer(object): self._clock = clock self.max_count = max_count - # key_to_defer is a map from the key to a 2 element list where - # the first element is the number of things executing, and - # the second element is an OrderedDict, where the keys are deferreds for the - # things blocked from executing. - self.key_to_defer = ( - {} - ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]] + # key_to_defer is a map from the key to a _LinearizerEntry. + self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] - def is_queued(self, key) -> bool: + def is_queued(self, key: Hashable) -> bool: """Checks whether there is a process queued up waiting """ entry = self.key_to_defer.get(key) @@ -233,25 +259,27 @@ class Linearizer(object): # There are waiting deferreds only in the OrderedDict of deferreds is # non-empty. - return bool(entry[1]) + return bool(entry.deferreds) - def queue(self, key): + def queue(self, key: Hashable) -> defer.Deferred: # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not # propagated inside inlineCallbacks until Twisted 18.7) - entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) + entry = self.key_to_defer.setdefault( + key, _LinearizerEntry(0, collections.OrderedDict()) + ) # If the number of things executing is greater than the maximum # then add a deferred to the list of blocked items # When one of the things currently executing finishes it will callback # this item so that it can continue executing. - if entry[0] >= self.max_count: + if entry.count >= self.max_count: res = self._await_lock(key) else: logger.debug( "Acquired uncontended linearizer lock %r for key %r", self.name, key ) - entry[0] += 1 + entry.count += 1 res = defer.succeed(None) # once we successfully get the lock, we need to return a context manager which @@ -266,15 +294,15 @@ class Linearizer(object): # We've finished executing so check if there are any things # blocked waiting to execute and start one of them - entry[0] -= 1 + entry.count -= 1 - if entry[1]: - (next_def, _) = entry[1].popitem(last=False) + if entry.deferreds: + (next_def, _) = entry.deferreds.popitem(last=False) # we need to run the next thing in the sentinel context. with PreserveLoggingContext(): next_def.callback(None) - elif entry[0] == 0: + elif entry.count == 0: # We were the last thing for this key: remove it from the # map. del self.key_to_defer[key] @@ -282,7 +310,7 @@ class Linearizer(object): res.addCallback(_ctx_manager) return res - def _await_lock(self, key): + def _await_lock(self, key: Hashable) -> defer.Deferred: """Helper for queue: adds a deferred to the queue Assumes that we've already checked that we've reached the limit of the number @@ -297,11 +325,11 @@ class Linearizer(object): logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer = make_deferred_yieldable(defer.Deferred()) - entry[1][new_defer] = 1 + entry.deferreds[new_defer] = 1 def cb(_r): logger.debug("Acquired linearizer lock %r for key %r", self.name, key) - entry[0] += 1 + entry.count += 1 # if the code holding the lock completes synchronously, then it # will recursively run the next claimant on the list. That can @@ -330,19 +358,19 @@ class Linearizer(object): ) # we just have to take ourselves back out of the queue. - del entry[1][new_defer] + del entry.deferreds[new_defer] return e new_defer.addCallbacks(cb, eb) return new_defer -class ReadWriteLock(object): - """A deferred style read write lock. +class ReadWriteLock: + """An async read write lock. Example: - with (yield read_write_lock.read("test_key")): + with await read_write_lock.read("test_key"): # do some work """ @@ -365,8 +393,7 @@ class ReadWriteLock(object): # Latest writer queued self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] - @defer.inlineCallbacks - def read(self, key): + async def read(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.setdefault(key, set()) @@ -376,7 +403,8 @@ class ReadWriteLock(object): # We wait for the latest writer to finish writing. We can safely ignore # any existing readers... as they're readers. - yield make_deferred_yieldable(curr_writer) + if curr_writer: + await make_deferred_yieldable(curr_writer) @contextmanager def _ctx_manager(): @@ -388,8 +416,7 @@ class ReadWriteLock(object): return _ctx_manager() - @defer.inlineCallbacks - def write(self, key): + async def write(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.get(key, set()) @@ -405,7 +432,7 @@ class ReadWriteLock(object): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) + await make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): @@ -419,14 +446,12 @@ class ReadWriteLock(object): return _ctx_manager() -def _cancelled_to_timed_out_error(value, timeout): - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise defer.TimeoutError(timeout, "Deferred") - return value +R = TypeVar("R") -def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): +def timeout_deferred( + deferred: defer.Deferred, timeout: float, reactor: IReactorTime, +) -> defer.Deferred: """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new deferred that wraps and times out the given deferred, correctly handling @@ -434,27 +459,21 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): (See https://twistedmatrix.com/trac/ticket/9534) - NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred + NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred. - Args: - deferred (Deferred) - timeout (float): Timeout in seconds - reactor (twisted.interfaces.IReactorTime): The twisted reactor to use - on_timeout_cancel (callable): A callable which is called immediately - after the deferred times out, and not if this deferred is - otherwise cancelled before the timeout. + NOTE: the TimeoutError raised by the resultant deferred is + twisted.internet.defer.TimeoutError, which is *different* to the built-in + TimeoutError, as well as various other TimeoutErrors you might have imported. - It takes an arbitrary value, which is the value of the deferred at - that exact point in time (probably a CancelledError Failure), and - the timeout. + Args: + deferred: The Deferred to potentially timeout. + timeout: Timeout in seconds + reactor: The twisted reactor to use - The default callable (if none is provided) will translate a - CancelledError Failure into a defer.TimeoutError. Returns: - Deferred + A new Deferred, which will errback with defer.TimeoutError on timeout. """ - new_d = defer.Deferred() timed_out = [False] @@ -467,18 +486,23 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): except: # noqa: E722, if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") + # the cancel() call should have set off a chain of errbacks which + # will have errbacked new_d, but in case it hasn't, errback it now. + if not new_d.called: - new_d.errback(defer.TimeoutError(timeout, "Deferred")) + new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,))) delayed_call = reactor.callLater(timeout, time_it_out) - def convert_cancelled(value): - if timed_out[0]: - to_call = on_timeout_cancel or _cancelled_to_timed_out_error - return to_call(value, timeout) + def convert_cancelled(value: failure.Failure): + # if the orgininal deferred was cancelled, and our timeout has fired, then + # the reason it was cancelled was due to our timeout. Turn the CancelledError + # into a TimeoutError. + if timed_out[0] and value.check(CancelledError): + raise defer.TimeoutError("Timed out after %gs" % (timeout,)) return value - deferred.addBoth(convert_cancelled) + deferred.addErrback(convert_cancelled) def cancel_timeout(result): # stop the pending call to cancel the deferred if it's been fired @@ -502,7 +526,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): @attr.s(slots=True, frozen=True) -class DoneAwaitable(object): +class DoneAwaitable: """Simple awaitable that returns the provided value. """ diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index dd356bf156..8fc05be278 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py
@@ -42,8 +42,8 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -@attr.s -class CacheMetric(object): +@attr.s(slots=True) +class CacheMetric: _cache = attr.ib() _cache_type = attr.ib(type=str) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 9b09c08b89..98b34f2223 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
@@ -18,11 +18,10 @@ import functools import inspect import logging import threading -from typing import Any, Tuple, Union, cast +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary from prometheus_client import Gauge -from typing_extensions import Protocol from twisted.internet import defer @@ -38,8 +37,10 @@ logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] +F = TypeVar("F", bound=Callable[..., Any]) -class _CachedFunction(Protocol): + +class _CachedFunction(Generic[F]): invalidate = None # type: Any invalidate_all = None # type: Any invalidate_many = None # type: Any @@ -47,8 +48,11 @@ class _CachedFunction(Protocol): cache = None # type: Any num_args = None # type: Any - def __name__(self): - ... + __name__ = None # type: str + + # Note: This function signature is actually fiddled with by the synapse mypy + # plugin to a) make it a bound method, and b) remove any `cache_context` arg. + __call__ = None # type: F cache_pending_metric = Gauge( @@ -60,7 +64,7 @@ cache_pending_metric = Gauge( _CacheSentinel = object() -class CacheEntry(object): +class CacheEntry: __slots__ = ["deferred", "callbacks", "invalidated"] def __init__(self, deferred, callbacks): @@ -76,7 +80,7 @@ class CacheEntry(object): self.callbacks.clear() -class Cache(object): +class Cache: __slots__ = ( "cache", "name", @@ -123,7 +127,7 @@ class Cache(object): self.name = name self.keylen = keylen - self.thread = None + self.thread = None # type: Optional[threading.Thread] self.metrics = register_cache( "cache", name, @@ -192,7 +196,7 @@ class Cache(object): callbacks = [callback] if callback else [] self.check_thread() observable = ObservableDeferred(value, consumeErrors=True) - observer = defer.maybeDeferred(observable.observe) + observer = observable.observe() entry = CacheEntry(deferred=observable, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) @@ -284,17 +288,10 @@ class Cache(object): self._pending_deferred_cache.clear() -class _CacheDescriptorBase(object): - def __init__( - self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False - ): +class _CacheDescriptorBase: + def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args @@ -364,7 +361,7 @@ class CacheDescriptor(_CacheDescriptorBase): invalidated) by adding a special "cache_context" argument to the function and passing that as a kwarg to all caches called. For example:: - @cachedInlineCallbacks(cache_context=True) + @cached(cache_context=True) def foo(self, key, cache_context): r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) @@ -382,17 +379,11 @@ class CacheDescriptor(_CacheDescriptorBase): max_entries=1000, num_args=None, tree=False, - inlineCallbacks=False, cache_context=False, iterable=False, ): - super(CacheDescriptor, self).__init__( - orig, - num_args=num_args, - inlineCallbacks=inlineCallbacks, - cache_context=cache_context, - ) + super().__init__(orig, num_args=num_args, cache_context=cache_context) self.max_entries = max_entries self.tree = tree @@ -465,9 +456,7 @@ class CacheDescriptor(_CacheDescriptorBase): observer = defer.succeed(cached_result_d) except KeyError: - ret = defer.maybeDeferred( - preserve_fn(self.function_to_call), obj, *args, **kwargs - ) + ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) def onErr(f): cache.invalidate(cache_key) @@ -510,9 +499,7 @@ class CacheListDescriptor(_CacheDescriptorBase): of results. """ - def __init__( - self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False - ): + def __init__(self, orig, cached_method_name, list_name, num_args=None): """ Args: orig (function) @@ -521,12 +508,8 @@ class CacheListDescriptor(_CacheDescriptorBase): num_args (int): number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all named args of the function. - inlineCallbacks (bool): Whether orig is a generator that should - be wrapped by defer.inlineCallbacks """ - super(CacheListDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks - ) + super().__init__(orig, num_args=num_args) self.list_name = list_name @@ -631,7 +614,7 @@ class CacheListDescriptor(_CacheDescriptorBase): cached_defers.append( defer.maybeDeferred( - preserve_fn(self.function_to_call), **args_to_call + preserve_fn(self.orig), **args_to_call ).addCallbacks(complete_all, errback) ) @@ -683,9 +666,13 @@ class _CacheContext: def cached( - max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False -): - return lambda orig: CacheDescriptor( + max_entries: int = 1000, + num_args: Optional[int] = None, + tree: bool = False, + cache_context: bool = False, + iterable: bool = False, +) -> Callable[[F], _CachedFunction[F]]: + func = lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -694,22 +681,12 @@ def cached( iterable=iterable, ) - -def cachedInlineCallbacks( - max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False -): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - tree=tree, - inlineCallbacks=True, - cache_context=cache_context, - iterable=iterable, - ) + return cast(Callable[[F], _CachedFunction[F]], func) -def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): +def cachedList( + cached_method_name: str, list_name: str, num_args: Optional[int] = None +) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -719,18 +696,16 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal cache. Args: - cached_method_name (str): The name of the single-item lookup method. + cached_method_name: The name of the single-item lookup method. This is only used to find the cache to use. - list_name (str): The name of the argument that is the list to use to + list_name: The name of the argument that is the list to use to do batch lookups in the cache. - num_args (int): Number of arguments to use as the key in the cache + num_args: Number of arguments to use as the key in the cache (including list_name). Defaults to all named parameters. - inlineCallbacks (bool): Should the function be wrapped in an - `defer.inlineCallbacks`? Example: - class Example(object): + class Example: @cached(num_args=2) def do_something(self, first_arg): ... @@ -739,10 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal def batch_do_something(self, first_arg, second_args): ... """ - return lambda orig: CacheListDescriptor( + func = lambda orig: CacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, - inlineCallbacks=inlineCallbacks, ) + + return cast(Callable[[F], _CachedFunction[F]], func) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 6834e6f3ae..8592b93689 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py
@@ -40,7 +40,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va return len(self.value) -class DictionaryCache(object): +class DictionaryCache: """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ @@ -53,7 +53,7 @@ class DictionaryCache(object): self.thread = None # caches_by_name[name] = self.cache - class Sentinel(object): + class Sentinel: __slots__ = [] self.sentinel = Sentinel() diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 89a3420f92..e15f7ee698 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) SENTINEL = object() -class ExpiringCache(object): +class ExpiringCache: def __init__( self, cache_name, @@ -190,7 +190,7 @@ class ExpiringCache(object): return False -class _CacheEntry(object): +class _CacheEntry: __slots__ = ["time", "value"] def __init__(self, time, value): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index df4ea5901d..4bc1a67b58 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py
@@ -30,7 +30,7 @@ def enumerate_leaves(node, depth): yield m -class _Node(object): +class _Node: __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] def __init__(self, prev_node, next_node, key, value, callbacks=set()): @@ -41,7 +41,7 @@ class _Node(object): self.callbacks = callbacks -class LruCache(object): +class LruCache: """ Least-recently-used cache. Supports del_multi only if cache_type=TreeCache diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index a6c60888e5..df1a721add 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py
@@ -23,7 +23,7 @@ from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -class ResponseCache(object): +class ResponseCache: """ This caches a deferred response. Until the deferred completes it will be returned from the cache. This means that if the client retries the request diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index ecd9948e79..eb4d98f683 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py
@@ -3,7 +3,7 @@ from typing import Dict SENTINEL = object() -class TreeCache(object): +class TreeCache: """ Tree-based backing store for LruCache. Allows subtrees of data to be deleted efficiently. @@ -89,7 +89,7 @@ def iterate_tree_cache_entry(d): yield d -class _Entry(object): +class _Entry: __slots__ = ["value"] def __init__(self, value): diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 6437aa907e..3e180cafd3 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) SENTINEL = object() -class TTLCache(object): +class TTLCache: """A key/value cache implementation where each entry has its own TTL""" def __init__(self, cache_name, timer=time.time): @@ -154,7 +154,7 @@ class TTLCache(object): @attr.s(frozen=True, slots=True) -class _CacheEntry(object): +class _CacheEntry: """TTLCache entry""" # expiry_time is the first attribute, so that entries are sorted by expiry. diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py new file mode 100644
index 0000000000..23393cf49b --- /dev/null +++ b/synapse/util/daemonize.py
@@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2012, 2013, 2014 Ilya Otyutskiy <ilya.otyutskiy@icloud.com> +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import fcntl +import logging +import os +import signal +import sys + + +def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None: + """daemonize the current process + + This calls fork(), and has the main process exit. When it returns we will be + running in the child process. + """ + + # If pidfile already exists, we should read pid from there; to overwrite it, if + # locking will fail, because locking attempt somehow purges the file contents. + if os.path.isfile(pid_file): + with open(pid_file, "r") as pid_fh: + old_pid = pid_fh.read() + + # Create a lockfile so that only one instance of this daemon is running at any time. + try: + lock_fh = open(pid_file, "w") + except IOError: + print("Unable to create the pidfile.") + sys.exit(1) + + try: + # Try to get an exclusive lock on the file. This will fail if another process + # has the file locked. + fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB) + except IOError: + print("Unable to lock on the pidfile.") + # We need to overwrite the pidfile if we got here. + # + # XXX better to avoid overwriting it, surely. this looks racey as the pid file + # could be created between us trying to read it and us trying to lock it. + with open(pid_file, "w") as pid_fh: + pid_fh.write(old_pid) + sys.exit(1) + + # Fork, creating a new process for the child. + process_id = os.fork() + + if process_id != 0: + # parent process: exit. + + # we use os._exit to avoid running the atexit handlers. In particular, that + # means we don't flush the logs. This is important because if we are using + # a MemoryHandler, we could have logs buffered which are now buffered in both + # the main and the child process, so if we let the main process flush the logs, + # we'll get two copies. + os._exit(0) + + # This is the child process. Continue. + + # Stop listening for signals that the parent process receives. + # This is done by getting a new process id. + # setpgrp() is an alternative to setsid(). + # setsid puts the process in a new parent group and detaches its controlling + # terminal. + + os.setsid() + + # point stdin, stdout, stderr at /dev/null + devnull = "/dev/null" + if hasattr(os, "devnull"): + # Python has set os.devnull on this system, use it instead as it might be + # different than /dev/null. + devnull = os.devnull + + devnull_fd = os.open(devnull, os.O_RDWR) + os.dup2(devnull_fd, 0) + os.dup2(devnull_fd, 1) + os.dup2(devnull_fd, 2) + os.close(devnull_fd) + + # now that we have redirected stderr to /dev/null, any uncaught exceptions will + # get sent to /dev/null, so make sure we log them. + # + # (we don't normally expect reactor.run to raise any exceptions, but this will + # also catch any other uncaught exceptions before we get that far.) + + def excepthook(type_, value, traceback): + logger.critical("Unhanded exception", exc_info=(type_, value, traceback)) + + sys.excepthook = excepthook + + # Set umask to default to safe file permissions when running as a root daemon. 027 + # is an octal number which we are typing as 0o27 for Python3 compatibility. + os.umask(0o27) + + # Change to a known directory. If this isn't done, starting a daemon in a + # subdirectory that needs to be deleted results in "directory busy" errors. + os.chdir(chdir) + + try: + lock_fh.write("%s" % (os.getpid())) + lock_fh.flush() + except IOError: + logger.error("Unable to write pid to the pidfile.") + print("Unable to write pid to the pidfile.") + sys.exit(1) + + # write a log line on SIGTERM. + def sigterm(signum, frame): + logger.warning("Caught signal %s. Stopping daemon." % signum) + sys.exit(0) + + signal.signal(signal.SIGTERM, sigterm) + + # Cleanup pid file at exit. + def exit(): + logger.warning("Stopping daemon.") + os.remove(pid_file) + sys.exit(0) + + atexit.register(exit) + + logger.warning("Starting daemon.") diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 22a857a306..f73e95393c 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py
@@ -16,8 +16,6 @@ import inspect import logging from twisted.internet import defer -from twisted.internet.defer import Deferred, fail, succeed -from twisted.python import failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -29,12 +27,7 @@ def user_left_room(distributor, user, room_id): distributor.fire("user_left_room", user=user, room_id=room_id) -# XXX: this is no longer used. We should probably kill it. -def user_joined_room(distributor, user, room_id): - distributor.fire("user_joined_room", user=user, room_id=room_id) - - -class Distributor(object): +class Distributor: """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. @@ -81,29 +74,7 @@ class Distributor(object): run_as_background_process(name, self.signals[name].fire, *args, **kwargs) -def maybeAwaitableDeferred(f, *args, **kw): - """ - Invoke a function that may or may not return a Deferred or an Awaitable. - - This is a modified version of twisted.internet.defer.maybeDeferred. - """ - try: - result = f(*args, **kw) - except Exception: - return fail(failure.Failure(captureVars=Deferred.debug)) - - if isinstance(result, Deferred): - return result - # Handle the additional case of an awaitable being returned. - elif inspect.isawaitable(result): - return defer.ensureDeferred(result) - elif isinstance(result, failure.Failure): - return fail(result) - else: - return succeed(result) - - -class Signal(object): +class Signal: """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -132,22 +103,17 @@ class Signal(object): Returns a Deferred that will complete when all the observers have completed.""" - def do(observer): - def eb(failure): + async def do(observer): + try: + result = observer(*args, **kwargs) + if inspect.isawaitable(result): + result = await result + return result + except Exception as e: logger.warning( - "%s signal observer %s failed: %r", - self.name, - observer, - failure, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject(), - ), + "%s signal observer %s failed: %r", self.name, observer, e, ) - return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb) - deferreds = [run_in_background(do, o) for o in self.observers] return make_deferred_yieldable( diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 6a3f6177b1..733f5e26e6 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py
@@ -20,7 +20,7 @@ from twisted.internet import threads from synapse.logging.context import make_deferred_yieldable, run_in_background -class BackgroundFileConsumer(object): +class BackgroundFileConsumer: """A consumer that writes to a file like object. Supports both push and pull producers diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index eab78dd256..bf094c9386 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py
@@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import json +import json + from frozendict import frozendict @@ -63,5 +64,8 @@ def _handle_frozendict(obj): ) -# A JSONEncoder which is capable of encoding frozendicts without barfing -frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) +# A JSONEncoder which is capable of encoding frozendicts without barfing. +# Additionally reduce the whitespace produced by JSON encoding. +frozendict_json_encoder = json.JSONEncoder( + allow_nan=False, separators=(",", ":"), default=_handle_frozendict, +) diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 6dce03dd3a..50516926f3 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py
@@ -14,7 +14,7 @@ # limitations under the License. -class JsonEncodedObject(object): +class JsonEncodedObject: """ A common base class for defining protocol units that are represented as JSON. diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 631654f297..da24ba0470 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py
@@ -94,7 +94,7 @@ class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" def connectionMade(self): - super(SynapseManhole, self).connectionMade() + super().connectionMade() # replace the manhole interpreter with our own impl self.interpreter = SynapseManholeInterpreter(self, self.namespace) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ec61e14423..ffdea0de8d 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py
@@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging from functools import wraps +from typing import Any, Callable, Optional, TypeVar, cast from prometheus_client import Counter -from twisted.internet import defer - -from synapse.logging.context import LoggingContext, current_context +from synapse.logging.context import ( + ContextResourceUsage, + LoggingContext, + current_context, +) from synapse.metrics import InFlightGauge logger = logging.getLogger(__name__) @@ -60,34 +62,42 @@ in_flight = InFlightGauge( sub_metrics=["real_time_max", "real_time_sum"], ) +T = TypeVar("T", bound=Callable[..., Any]) -def measure_func(name=None): - def wrapper(func): - block_name = func.__name__ if name is None else name - if inspect.iscoroutinefunction(func): +def measure_func(name: Optional[str] = None) -> Callable[[T], T]: + """ + Used to decorate an async function with a `Measure` context manager. + + Usage: + + @measure_func() + async def foo(...): + ... - @wraps(func) - async def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = await func(self, *args, **kwargs) - return r + Which is analogous to: - else: + async def foo(...): + with Measure(...): + ... - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = yield func(self, *args, **kwargs) - return r + """ - return measured_func + def wrapper(func: T) -> T: + block_name = func.__name__ if name is None else name + + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r + + return cast(T, measured_func) return wrapper -class Measure(object): +class Measure: __slots__ = [ "clock", "name", @@ -98,27 +108,27 @@ class Measure(object): def __init__(self, clock, name): self.clock = clock self.name = name - self._logging_context = None + parent_context = current_context() + self._logging_context = LoggingContext( + "Measure[%s]" % (self.name,), parent_context + ) self.start = None - def __enter__(self): - if self._logging_context: + def __enter__(self) -> "Measure": + if self.start is not None: raise RuntimeError("Measure() objects cannot be re-used") self.start = self.clock.time() - parent_context = current_context() - self._logging_context = LoggingContext( - "Measure[%s]" % (self.name,), parent_context - ) self._logging_context.__enter__() in_flight.register((self.name,), self._update_in_flight) + return self def __exit__(self, exc_type, exc_val, exc_tb): - if not self._logging_context: + if self.start is None: raise RuntimeError("Measure() block exited without being entered") duration = self.clock.time() - self.start - usage = self._logging_context.get_resource_usage() + usage = self.get_resource_usage() in_flight.unregister((self.name,), self._update_in_flight) self._logging_context.__exit__(exc_type, exc_val, exc_tb) @@ -134,6 +144,13 @@ class Measure(object): except ValueError: logger.warning("Failed to save metrics! Usage: %s", usage) + def get_resource_usage(self) -> ContextResourceUsage: + """Get the resources used within this Measure block + + If the Measure block is still active, returns the resource usage so far. + """ + return self._logging_context.get_resource_usage() + def _update_in_flight(self, metrics): """Gets called when processing in flight metrics """ diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 54c046b6e1..72574d3af2 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py
@@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import functools import sys from typing import Any, Callable, List diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index e5efdfcd02..70d11e1ec3 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py
@@ -29,7 +29,7 @@ from synapse.logging.context import ( logger = logging.getLogger(__name__) -class FederationRateLimiter(object): +class FederationRateLimiter: def __init__(self, clock, config): """ Args: @@ -60,7 +60,7 @@ class FederationRateLimiter(object): return self.ratelimiters[host].ratelimit() -class _PerHostRatelimiter(object): +class _PerHostRatelimiter: def __init__(self, clock, config): """ Args: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 8794317caa..a5cc9d0551 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py
@@ -15,8 +15,6 @@ import logging import random -from twisted.internet import defer - import synapse.logging.context from synapse.api.errors import CodeMessageException @@ -47,15 +45,14 @@ class NotRetryingDestination(Exception): """ msg = "Not retrying server %s." % (destination,) - super(NotRetryingDestination, self).__init__(msg) + super().__init__(msg) self.retry_last_ts = retry_last_ts self.retry_interval = retry_interval self.destination = destination -@defer.inlineCallbacks -def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): +async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. If we are not ready to retry the destination, this will raise a @@ -73,9 +70,9 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) Example usage: try: - limiter = yield get_retry_limiter(destination, clock, store) + limiter = await get_retry_limiter(destination, clock, store) with limiter: - response = yield do_request() + response = await do_request() except NotRetryingDestination: # We aren't ready to retry that destination. raise @@ -83,7 +80,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) failure_ts = None retry_last_ts, retry_interval = (0, 0) - retry_timings = yield store.get_destination_retry_timings(destination) + retry_timings = await store.get_destination_retry_timings(destination) if retry_timings: failure_ts = retry_timings["failure_ts"] @@ -117,7 +114,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) ) -class RetryDestinationLimiter(object): +class RetryDestinationLimiter: def __init__( self, destination, @@ -222,10 +219,9 @@ class RetryDestinationLimiter(object): if self.failure_ts is None: self.failure_ts = retry_last_ts - @defer.inlineCallbacks - def store_retry_timings(): + async def store_retry_timings(): try: - yield self.store.set_destination_retry_timings( + await self.store.set_destination_retry_timings( self.destination, self.failure_ts, retry_last_ts, diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 2e2b40a426..61d96a6c28 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py
@@ -24,9 +24,7 @@ from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken -# Note: The : character is allowed here for older clients, but will be removed in a -# future release. Context: https://github.com/matrix-org/synapse/issues/6766 -client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$") +client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 023beb5ede..be3b22469d 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py
@@ -14,7 +14,7 @@ # limitations under the License. -class _Entry(object): +class _Entry: __slots__ = ["end_key", "queue"] def __init__(self, end_key): @@ -22,7 +22,7 @@ class _Entry(object): self.queue = [] -class WheelTimer(object): +class WheelTimer: """Stores arbitrary objects that will be returned after their timers have expired. """