diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/__init__.py | 2 | ||||
-rw-r--r-- | synapse/util/async_helpers.py | 135 | ||||
-rw-r--r-- | synapse/util/distributor.py | 50 | ||||
-rw-r--r-- | synapse/util/frozenutils.py | 5 |
4 files changed, 97 insertions, 95 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index a13f11f8d8..60ecc498ab 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import re import attr -from canonicaljson import json from twisted.internet import defer, task diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index bb57e27beb..67ce9a5f39 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -17,13 +17,25 @@ import collections import logging from contextlib import contextmanager -from typing import Dict, Sequence, Set, Union +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + TypeVar, + Union, +) import attr from typing_extensions import ContextManager from twisted.internet import defer from twisted.internet.defer import CancelledError +from twisted.internet.interfaces import IReactorTime from twisted.python import failure from synapse.logging.context import ( @@ -54,7 +66,7 @@ class ObservableDeferred: __slots__ = ["_deferred", "_observers", "_result"] - def __init__(self, deferred, consumeErrors=False): + def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) object.__setattr__(self, "_observers", set()) @@ -111,25 +123,25 @@ class ObservableDeferred: success, res = self._result return defer.succeed(res) if success else defer.fail(res) - def observers(self): + def observers(self) -> List[defer.Deferred]: return self._observers - def has_called(self): + def has_called(self) -> bool: return self._result is not None - def has_succeeded(self): + def has_succeeded(self) -> bool: return self._result is not None and self._result[0] is True - def get_result(self): + def get_result(self) -> Any: return self._result[1] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self._deferred, name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: setattr(self._deferred, name, value) - def __repr__(self): + def __repr__(self) -> str: return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % ( id(self), self._result, @@ -137,18 +149,20 @@ class ObservableDeferred: ) -def concurrently_execute(func, args, limit): - """Executes the function with each argument conncurrently while limiting +def concurrently_execute( + func: Callable, args: Iterable[Any], limit: int +) -> defer.Deferred: + """Executes the function with each argument concurrently while limiting the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred or coroutine. - args (Iterable): List of arguments to pass to func, each invocation of func + func: Function to execute, should return a deferred or coroutine. + args: List of arguments to pass to func, each invocation of func gets a single argument. - limit (int): Maximum number of conccurent executions. + limit: Maximum number of conccurent executions. Returns: - deferred: Resolved when all function invocations have finished. + Deferred[list]: Resolved when all function invocations have finished. """ it = iter(args) @@ -167,14 +181,17 @@ def concurrently_execute(func, args, limit): ).addErrback(unwrapFirstError) -def yieldable_gather_results(func, iter, *args, **kwargs): +def yieldable_gather_results( + func: Callable, iter: Iterable, *args: Any, **kwargs: Any +) -> defer.Deferred: """Executes the function with each argument concurrently. Args: - func (func): Function to execute that returns a Deferred - iter (iter): An iterable that yields items that get passed as the first + func: Function to execute that returns a Deferred + iter: An iterable that yields items that get passed as the first argument to the function *args: Arguments to be passed to each call to func + **kwargs: Keyword arguments to be passed to each call to func Returns Deferred[list]: Resolved when all functions have been invoked, or errors if @@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs): ).addErrback(unwrapFirstError) +@attr.s(slots=True) +class _LinearizerEntry: + # The number of things executing. + count = attr.ib(type=int) + # Deferreds for the things blocked from executing. + deferreds = attr.ib(type=collections.OrderedDict) + + class Linearizer: """Limits concurrent access to resources based on a key. Useful to ensure only a few things happen at a time on a given resource. Example: - with (yield limiter.queue("test_key")): + with await limiter.queue("test_key"): # do some work. """ - def __init__(self, name=None, max_count=1, clock=None): + def __init__( + self, + name: Optional[str] = None, + max_count: int = 1, + clock: Optional[Clock] = None, + ): """ Args: - max_count(int): The maximum number of concurrent accesses + max_count: The maximum number of concurrent accesses """ if name is None: - self.name = id(self) + self.name = id(self) # type: Union[str, int] else: self.name = name @@ -216,15 +246,10 @@ class Linearizer: self._clock = clock self.max_count = max_count - # key_to_defer is a map from the key to a 2 element list where - # the first element is the number of things executing, and - # the second element is an OrderedDict, where the keys are deferreds for the - # things blocked from executing. - self.key_to_defer = ( - {} - ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]] + # key_to_defer is a map from the key to a _LinearizerEntry. + self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] - def is_queued(self, key) -> bool: + def is_queued(self, key: Hashable) -> bool: """Checks whether there is a process queued up waiting """ entry = self.key_to_defer.get(key) @@ -234,25 +259,27 @@ class Linearizer: # There are waiting deferreds only in the OrderedDict of deferreds is # non-empty. - return bool(entry[1]) + return bool(entry.deferreds) - def queue(self, key): + def queue(self, key: Hashable) -> defer.Deferred: # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not # propagated inside inlineCallbacks until Twisted 18.7) - entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) + entry = self.key_to_defer.setdefault( + key, _LinearizerEntry(0, collections.OrderedDict()) + ) # If the number of things executing is greater than the maximum # then add a deferred to the list of blocked items # When one of the things currently executing finishes it will callback # this item so that it can continue executing. - if entry[0] >= self.max_count: + if entry.count >= self.max_count: res = self._await_lock(key) else: logger.debug( "Acquired uncontended linearizer lock %r for key %r", self.name, key ) - entry[0] += 1 + entry.count += 1 res = defer.succeed(None) # once we successfully get the lock, we need to return a context manager which @@ -267,15 +294,15 @@ class Linearizer: # We've finished executing so check if there are any things # blocked waiting to execute and start one of them - entry[0] -= 1 + entry.count -= 1 - if entry[1]: - (next_def, _) = entry[1].popitem(last=False) + if entry.deferreds: + (next_def, _) = entry.deferreds.popitem(last=False) # we need to run the next thing in the sentinel context. with PreserveLoggingContext(): next_def.callback(None) - elif entry[0] == 0: + elif entry.count == 0: # We were the last thing for this key: remove it from the # map. del self.key_to_defer[key] @@ -283,7 +310,7 @@ class Linearizer: res.addCallback(_ctx_manager) return res - def _await_lock(self, key): + def _await_lock(self, key: Hashable) -> defer.Deferred: """Helper for queue: adds a deferred to the queue Assumes that we've already checked that we've reached the limit of the number @@ -298,11 +325,11 @@ class Linearizer: logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer = make_deferred_yieldable(defer.Deferred()) - entry[1][new_defer] = 1 + entry.deferreds[new_defer] = 1 def cb(_r): logger.debug("Acquired linearizer lock %r for key %r", self.name, key) - entry[0] += 1 + entry.count += 1 # if the code holding the lock completes synchronously, then it # will recursively run the next claimant on the list. That can @@ -331,7 +358,7 @@ class Linearizer: ) # we just have to take ourselves back out of the queue. - del entry[1][new_defer] + del entry.deferreds[new_defer] return e new_defer.addCallbacks(cb, eb) @@ -419,14 +446,22 @@ class ReadWriteLock: return _ctx_manager() -def _cancelled_to_timed_out_error(value, timeout): +R = TypeVar("R") + + +def _cancelled_to_timed_out_error(value: R, timeout: float) -> R: if isinstance(value, failure.Failure): value.trap(CancelledError) raise defer.TimeoutError(timeout, "Deferred") return value -def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): +def timeout_deferred( + deferred: defer.Deferred, + timeout: float, + reactor: IReactorTime, + on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None, +) -> defer.Deferred: """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new deferred that wraps and times out the given deferred, correctly handling @@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred Args: - deferred (Deferred) - timeout (float): Timeout in seconds - reactor (twisted.interfaces.IReactorTime): The twisted reactor to use - on_timeout_cancel (callable): A callable which is called immediately + deferred: The Deferred to potentially timeout. + timeout: Timeout in seconds + reactor: The twisted reactor to use + on_timeout_cancel: A callable which is called immediately after the deferred times out, and not if this deferred is otherwise cancelled before the timeout. @@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): CancelledError Failure into a defer.TimeoutError. Returns: - Deferred + A new Deferred. """ new_d = defer.Deferred() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index a750261e77..f73e95393c 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -16,8 +16,6 @@ import inspect import logging from twisted.internet import defer -from twisted.internet.defer import Deferred, fail, succeed -from twisted.python import failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -29,11 +27,6 @@ def user_left_room(distributor, user, room_id): distributor.fire("user_left_room", user=user, room_id=room_id) -# XXX: this is no longer used. We should probably kill it. -def user_joined_room(distributor, user, room_id): - distributor.fire("user_joined_room", user=user, room_id=room_id) - - class Distributor: """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. @@ -81,28 +74,6 @@ class Distributor: run_as_background_process(name, self.signals[name].fire, *args, **kwargs) -def maybeAwaitableDeferred(f, *args, **kw): - """ - Invoke a function that may or may not return a Deferred or an Awaitable. - - This is a modified version of twisted.internet.defer.maybeDeferred. - """ - try: - result = f(*args, **kw) - except Exception: - return fail(failure.Failure(captureVars=Deferred.debug)) - - if isinstance(result, Deferred): - return result - # Handle the additional case of an awaitable being returned. - elif inspect.isawaitable(result): - return defer.ensureDeferred(result) - elif isinstance(result, failure.Failure): - return fail(result) - else: - return succeed(result) - - class Signal: """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -132,22 +103,17 @@ class Signal: Returns a Deferred that will complete when all the observers have completed.""" - def do(observer): - def eb(failure): + async def do(observer): + try: + result = observer(*args, **kwargs) + if inspect.isawaitable(result): + result = await result + return result + except Exception as e: logger.warning( - "%s signal observer %s failed: %r", - self.name, - observer, - failure, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject(), - ), + "%s signal observer %s failed: %r", self.name, observer, e, ) - return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb) - deferreds = [run_in_background(do, o) for o in self.observers] return make_deferred_yieldable( diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 0e445e01d7..bf094c9386 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import json +import json + from frozendict import frozendict @@ -66,5 +67,5 @@ def _handle_frozendict(obj): # A JSONEncoder which is capable of encoding frozendicts without barfing. # Additionally reduce the whitespace produced by JSON encoding. frozendict_json_encoder = json.JSONEncoder( - default=_handle_frozendict, separators=(",", ":"), + allow_nan=False, separators=(",", ":"), default=_handle_frozendict, ) |