diff options
Diffstat (limited to 'synapse/util')
27 files changed, 317 insertions, 314 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 0ae7e2ef3b..dcc747cac1 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -40,6 +40,7 @@ class Clock(object): Args: reactor: The Twisted reactor to use. """ + _reactor = attr.ib() @defer.inlineCallbacks @@ -70,9 +71,7 @@ class Clock(object): call = task.LoopingCall(f) call.clock = self._reactor d = call.start(msec / 1000.0, now=False) - d.addErrback( - log_failure, "Looping call died", consumeErrors=False, - ) + d.addErrback(log_failure, "Looping call died", consumeErrors=False) return call def call_later(self, delay, callback, *args, **kwargs): @@ -84,6 +83,7 @@ class Clock(object): *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. """ + def wrapped_callback(*args, **kwargs): with PreserveLoggingContext(): callback(*args, **kwargs) @@ -129,12 +129,7 @@ def log_failure(failure, msg, consumeErrors=True): """ logger.error( - msg, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject() - ) + msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) ) if not consumeErrors: @@ -152,12 +147,12 @@ def glob_to_regex(glob): Returns: re.RegexObject """ - res = '' + res = "" for c in glob: - if c == '*': - res = res + '.*' - elif c == '?': - res = res + '.' + if c == "*": + res = res + ".*" + elif c == "?": + res = res + "." else: res = res + re.escape(c) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 7253ba120f..7757b8708a 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -95,6 +95,7 @@ class ObservableDeferred(object): def remove(r): self._observers.discard(d) return r + d.addBoth(remove) self._observers.add(d) @@ -123,7 +124,9 @@ class ObservableDeferred(object): def __repr__(self): return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % ( - id(self), self._result, self._deferred, + id(self), + self._result, + self._deferred, ) @@ -150,10 +153,12 @@ def concurrently_execute(func, args, limit): except StopIteration: pass - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(_concurrently_execute_inner) - for _ in range(limit) - ], consumeErrors=True)).addErrback(unwrapFirstError) + return logcontext.make_deferred_yieldable( + defer.gatherResults( + [run_in_background(_concurrently_execute_inner) for _ in range(limit)], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) def yieldable_gather_results(func, iter, *args, **kwargs): @@ -169,10 +174,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs): Deferred[list]: Resolved when all functions have been invoked, or errors if one of the function calls fails. """ - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(func, item, *args, **kwargs) - for item in iter - ], consumeErrors=True)).addErrback(unwrapFirstError) + return logcontext.make_deferred_yieldable( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) class Linearizer(object): @@ -185,6 +192,7 @@ class Linearizer(object): # do some work. """ + def __init__(self, name=None, max_count=1, clock=None): """ Args: @@ -197,6 +205,7 @@ class Linearizer(object): if not clock: from twisted.internet import reactor + clock = Clock(reactor) self._clock = clock self.max_count = max_count @@ -221,7 +230,7 @@ class Linearizer(object): res = self._await_lock(key) else: logger.debug( - "Acquired uncontended linearizer lock %r for key %r", self.name, key, + "Acquired uncontended linearizer lock %r for key %r", self.name, key ) entry[0] += 1 res = defer.succeed(None) @@ -266,9 +275,7 @@ class Linearizer(object): """ entry = self.key_to_defer[key] - logger.debug( - "Waiting to acquire linearizer lock %r for key %r", self.name, key, - ) + logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer = make_deferred_yieldable(defer.Deferred()) entry[1][new_defer] = 1 @@ -293,14 +300,14 @@ class Linearizer(object): logger.info("defer %r got err %r", new_defer, e) if isinstance(e, CancelledError): logger.debug( - "Cancelling wait for linearizer lock %r for key %r", - self.name, key, + "Cancelling wait for linearizer lock %r for key %r", self.name, key ) else: logger.warn( "Unexpected exception waiting for linearizer lock %r for key %r", - self.name, key, + self.name, + key, ) # we just have to take ourselves back out of the queue. @@ -438,7 +445,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): try: deferred.cancel() - except: # noqa: E722, if we throw any exception it'll break time outs + except: # noqa: E722, if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") if not new_d.called: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index f37d5bec08..8271229015 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -104,8 +104,8 @@ def register_cache(cache_type, cache_name, cache): KNOWN_KEYS = { - key: key for key in - ( + key: key + for key in ( "auth_events", "content", "depth", @@ -150,7 +150,7 @@ def intern_dict(dictionary): def _intern_known_values(key, value): - intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",) + intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key") if key in intern_keys: return intern_string(value) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 187510576a..d2f25063aa 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -40,9 +40,7 @@ _CacheSentinel = object() class CacheEntry(object): - __slots__ = [ - "deferred", "callbacks", "invalidated" - ] + __slots__ = ["deferred", "callbacks", "invalidated"] def __init__(self, deferred, callbacks): self.deferred = deferred @@ -73,7 +71,9 @@ class Cache(object): self._pending_deferred_cache = cache_type() self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type, + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, ) @@ -133,10 +133,7 @@ class Cache(object): def set(self, key, value, callback=None): callbacks = [callback] if callback else [] self.check_thread() - entry = CacheEntry( - deferred=value, - callbacks=callbacks, - ) + entry = CacheEntry(deferred=value, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: @@ -191,9 +188,7 @@ class Cache(object): def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): - raise TypeError( - "The cache key must be a tuple not %r" % (type(key),) - ) + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the @@ -244,29 +239,25 @@ class _CacheDescriptorBase(object): raise Exception( "Not enough explicit positional arguments to key off for %r: " "got %i args, but wanted %i. (@cached cannot key off *args or " - "**kwargs)" - % (orig.__name__, len(all_args), num_args) + "**kwargs)" % (orig.__name__, len(all_args), num_args) ) self.num_args = num_args # list of the names of the args used as the cache key - self.arg_names = all_args[1:num_args + 1] + self.arg_names = all_args[1 : num_args + 1] # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: - self.arg_defaults = dict(zip( - all_args[-len(arg_spec.defaults):], - arg_spec.defaults - )) + self.arg_defaults = dict( + zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults) + ) else: self.arg_defaults = {} if "cache_context" in self.arg_names: - raise Exception( - "cache_context arg cannot be included among the cache keys" - ) + raise Exception("cache_context arg cannot be included among the cache keys") self.add_cache_context = cache_context @@ -304,12 +295,24 @@ class CacheDescriptor(_CacheDescriptorBase): ``cache_context``) to use as cache keys. Defaults to all named args of the function. """ - def __init__(self, orig, max_entries=1000, num_args=None, tree=False, - inlineCallbacks=False, cache_context=False, iterable=False): + + def __init__( + self, + orig, + max_entries=1000, + num_args=None, + tree=False, + inlineCallbacks=False, + cache_context=False, + iterable=False, + ): super(CacheDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks, - cache_context=cache_context) + orig, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + cache_context=cache_context, + ) max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) @@ -356,7 +359,9 @@ class CacheDescriptor(_CacheDescriptorBase): return args[0] else: return self.arg_defaults[nm] + else: + def get_cache_key(args, kwargs): return tuple(get_cache_key_gen(args, kwargs)) @@ -383,8 +388,7 @@ class CacheDescriptor(_CacheDescriptorBase): except KeyError: ret = defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - obj, *args, **kwargs + logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs ) def onErr(f): @@ -437,8 +441,9 @@ class CacheListDescriptor(_CacheDescriptorBase): results. """ - def __init__(self, orig, cached_method_name, list_name, num_args=None, - inlineCallbacks=False): + def __init__( + self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False + ): """ Args: orig (function) @@ -451,7 +456,8 @@ class CacheListDescriptor(_CacheDescriptorBase): be wrapped by defer.inlineCallbacks """ super(CacheListDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks) + orig, num_args=num_args, inlineCallbacks=inlineCallbacks + ) self.list_name = list_name @@ -463,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cached_method_name,) + % (self.list_name, cached_method_name) ) def __get__(self, obj, objtype=None): @@ -494,8 +500,10 @@ class CacheListDescriptor(_CacheDescriptorBase): # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: + def arg_to_cache_key(arg): return arg + else: keylist = list(keyargs) @@ -505,8 +513,7 @@ class CacheListDescriptor(_CacheDescriptorBase): for arg in list_args: try: - res = cache.get(arg_to_cache_key(arg), - callback=invalidate_callback) + res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): @@ -554,18 +561,15 @@ class CacheListDescriptor(_CacheDescriptorBase): args_to_call = dict(arg_dict) args_to_call[self.list_name] = list(missing) - cached_defers.append(defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - **args_to_call - ).addCallbacks(complete_all, errback)) + cached_defers.append( + defer.maybeDeferred( + logcontext.preserve_fn(self.function_to_call), **args_to_call + ).addCallbacks(complete_all, errback) + ) if cached_defers: - d = defer.gatherResults( - cached_defers, - consumeErrors=True, - ).addCallbacks( - lambda _: results, - unwrapFirstError + d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( + lambda _: results, unwrapFirstError ) return logcontext.make_deferred_yieldable(d) else: @@ -586,8 +590,9 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, - iterable=False): +def cached( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -598,8 +603,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, ) -def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, - cache_context=False, iterable=False): +def cachedInlineCallbacks( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 6c0b5a4094..6834e6f3ae 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va there. value (dict): The full or partial dict value """ + def __len__(self): return len(self.value) @@ -84,13 +85,15 @@ class DictionaryCache(object): self.metrics.inc_hits() if dict_keys is None: - return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value)) + return DictionaryEntry( + entry.full, entry.known_absent, dict(entry.value) + ) else: - return DictionaryEntry(entry.full, entry.known_absent, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) + return DictionaryEntry( + entry.full, + entry.known_absent, + {k: entry.value[k] for k in dict_keys if k in entry.value}, + ) self.metrics.inc_misses() return DictionaryEntry(False, set(), {}) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index f369780277..cddf1ed515 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -28,8 +28,15 @@ SENTINEL = object() class ExpiringCache(object): - def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False, iterable=False): + def __init__( + self, + cache_name, + clock, + max_len=0, + expiry_ms=0, + reset_expiry_on_get=False, + iterable=False, + ): """ Args: cache_name (str): Name of this cache, used for logging. @@ -67,8 +74,7 @@ class ExpiringCache(object): def f(): return run_as_background_process( - "prune_cache_%s" % self._cache_name, - self._prune_cache, + "prune_cache_%s" % self._cache_name, self._prune_cache ) self._clock.looping_call(f, self._expiry_ms / 2) @@ -153,7 +159,9 @@ class ExpiringCache(object): logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self) + self._cache_name, + begin_length, + len(self), ) def __len__(self): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index b684f24e7b..1536cb64f3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,8 +49,15 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None, - evicted_callback=None): + + def __init__( + self, + max_size, + keylen=1, + cache_type=dict, + size_callback=None, + evicted_callback=None, + ): """ Args: max_size (int): @@ -93,9 +100,12 @@ class LruCache(object): cached_cache_len = [0] if size_callback is not None: + def cache_len(): return cached_cache_len[0] + else: + def cache_len(): return len(cache) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index afb03b2e1b..b1da81633c 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -35,12 +35,10 @@ class ResponseCache(object): self.pending_result_cache = {} # Requests that haven't finished yet. self.clock = hs.get_clock() - self.timeout_sec = timeout_ms / 1000. + self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache( - "response_cache", name, self - ) + self._metrics = register_cache("response_cache", name, self) def size(self): return len(self.pending_result_cache) @@ -100,8 +98,7 @@ class ResponseCache(object): def remove(r): if self.timeout_sec: self.clock.call_later( - self.timeout_sec, - self.pending_result_cache.pop, key, None, + self.timeout_sec, self.pending_result_cache.pop, key, None ) else: self.pending_result_cache.pop(key, None) @@ -147,14 +144,15 @@ class ResponseCache(object): """ result = self.get(key) if not result: - logger.info("[%s]: no cached result for [%s], calculating new one", - self._name, key) + logger.info( + "[%s]: no cached result for [%s], calculating new one", self._name, key + ) d = run_in_background(callback, *args, **kwargs) result = self.set(key, d) elif not isinstance(result, defer.Deferred) or result.called: - logger.info("[%s]: using completed cached result for [%s]", - self._name, key) + logger.info("[%s]: using completed cached result for [%s]", self._name, key) else: - logger.info("[%s]: using incomplete cached result for [%s]", - self._name, key) + logger.info( + "[%s]: using incomplete cached result for [%s]", self._name, key + ) return make_deferred_yieldable(result) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 625aedc940..235f64049c 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -77,9 +77,8 @@ class StreamChangeCache(object): if stream_pos >= self._earliest_known_stream_pos: changed_entities = { - self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos), - ) + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) } result = changed_entities.intersection(entities) @@ -114,8 +113,10 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos >= self._earliest_known_stream_pos: - return [self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos))] + return [ + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) + ] else: return None @@ -136,7 +137,7 @@ class StreamChangeCache(object): while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) self._earliest_known_stream_pos = max( - k, self._earliest_known_stream_pos, + k, self._earliest_known_stream_pos ) self._entity_to_key.pop(r, None) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index dd4c9e6067..9a72218d85 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -9,6 +9,7 @@ class TreeCache(object): efficiently. Keys must be tuples. """ + def __init__(self): self.size = 0 self.root = {} diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 5ba1862506..2af8ca43b1 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -155,6 +155,7 @@ class TTLCache(object): @attr.s(frozen=True, slots=True) class _CacheEntry(object): """TTLCache entry""" + # expiry_time is the first attribute, so that entries are sorted by expiry. expiry_time = attr.ib() key = attr.ib() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index e14c8bdfda..5a79db821c 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -51,9 +51,7 @@ class Distributor(object): if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) - self.signals[name] = Signal( - name, - ) + self.signals[name] = Signal(name) if name in self.pre_registration: signal = self.signals[name] @@ -78,11 +76,7 @@ class Distributor(object): if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - run_as_background_process( - name, - self.signals[name].fire, - *args, **kwargs - ) + run_as_background_process(name, self.signals[name].fire, *args, **kwargs) class Signal(object): @@ -118,22 +112,23 @@ class Signal(object): def eb(failure): logger.warning( "%s signal observer %s failed: %r", - self.name, observer, failure, + self.name, + observer, + failure, exc_info=( failure.type, failure.value, - failure.getTracebackObject())) + failure.getTracebackObject(), + ), + ) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) - deferreds = [ - run_in_background(do, o) - for o in self.observers - ] + deferreds = [run_in_background(do, o) for o in self.observers] - return make_deferred_yieldable(defer.gatherResults( - deferreds, consumeErrors=True, - )) + return make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) def __repr__(self): return "<Signal name=%r>" % (self.name,) diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 014edea971..635b897d6c 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -60,11 +60,10 @@ def _handle_frozendict(obj): # fishing the protected dict out of the object is a bit nasty, # but we don't really want the overhead of copying the dict. return obj._dict - raise TypeError('Object of type %s is not JSON serializable' % - obj.__class__.__name__) + raise TypeError( + "Object of type %s is not JSON serializable" % obj.__class__.__name__ + ) # A JSONEncoder which is capable of encoding frozendics without barfing -frozendict_json_encoder = json.JSONEncoder( - default=_handle_frozendict, -) +frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 2d7ddc1cbe..1a20c596bf 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource): logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource - for path_seg in full_path.split(b'/')[1:-1]: + for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" child_resource = NoResource() @@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource): # =========================== # now attach the actual desired resource - last_path_seg = full_path.split(b'/')[-1] + last_path_seg = full_path.split(b"/")[-1] # if there is already a resource here, thieve its children and # replace it @@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource): # to be replaced with the desired resource. existing_dummy_resource = resource_mappings[res_id] for child_name in existing_dummy_resource.listNames(): - child_res_id = _resource_id( - existing_dummy_resource, child_name - ) + child_res_id = _resource_id(existing_dummy_resource, child_name) child_resource = resource_mappings[child_res_id] # steal the children res.putChild(child_name, child_resource) diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index d668e5a6b8..6dce03dd3a 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -70,7 +70,8 @@ class JsonEncodedObject(object): dict """ d = { - k: _encode(v) for (k, v) in self.__dict__.items() + k: _encode(v) + for (k, v) in self.__dict__.items() if k in self.valid_keys and k not in self.internal_keys } d.update(self.unrecognized_keys) @@ -78,7 +79,8 @@ class JsonEncodedObject(object): def get_internal_dict(self): d = { - k: _encode(v, internal=True) for (k, v) in self.__dict__.items() + k: _encode(v, internal=True) + for (k, v) in self.__dict__.items() if k in self.valid_keys } d.update(self.unrecognized_keys) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 645d71282e..10022ff620 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -42,6 +42,8 @@ try: def get_thread_resource_usage(): return resource.getrusage(RUSAGE_THREAD) + + except Exception: # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we # won't track resource usage by returning None. @@ -64,8 +66,11 @@ class ContextResourceUsage(object): """ __slots__ = [ - "ru_stime", "ru_utime", - "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec", + "ru_stime", + "ru_utime", + "db_txn_count", + "db_txn_duration_sec", + "db_sched_duration_sec", "evt_db_fetch_count", ] @@ -91,8 +96,8 @@ class ContextResourceUsage(object): return ContextResourceUsage(copy_from=self) def reset(self): - self.ru_stime = 0. - self.ru_utime = 0. + self.ru_stime = 0.0 + self.ru_utime = 0.0 self.db_txn_count = 0 self.db_txn_duration_sec = 0 @@ -100,15 +105,18 @@ class ContextResourceUsage(object): self.evt_db_fetch_count = 0 def __repr__(self): - return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', " - "db_txn_count='%r', db_txn_duration_sec='%r', " - "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % ( - self.ru_stime, - self.ru_utime, - self.db_txn_count, - self.db_txn_duration_sec, - self.db_sched_duration_sec, - self.evt_db_fetch_count,) + return ( + "<ContextResourceUsage ru_stime='%r', ru_utime='%r', " + "db_txn_count='%r', db_txn_duration_sec='%r', " + "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>" + ) % ( + self.ru_stime, + self.ru_utime, + self.db_txn_count, + self.db_txn_duration_sec, + self.db_sched_duration_sec, + self.evt_db_fetch_count, + ) def __iadd__(self, other): """Add another ContextResourceUsage's stats to this one's. @@ -159,11 +167,15 @@ class LoggingContext(object): """ __slots__ = [ - "previous_context", "name", "parent_context", + "previous_context", + "name", + "parent_context", "_resource_usage", "usage_start", - "main_thread", "alive", - "request", "tag", + "main_thread", + "alive", + "request", + "tag", ] thread_local = threading.local() @@ -196,6 +208,7 @@ class LoggingContext(object): def __nonzero__(self): return False + __bool__ = __nonzero__ # python3 sentinel = Sentinel() @@ -261,7 +274,8 @@ class LoggingContext(object): if self.previous_context != old_context: logger.warn( "Expected previous context %r, found %r", - self.previous_context, old_context + self.previous_context, + old_context, ) self.alive = True @@ -285,9 +299,8 @@ class LoggingContext(object): self.alive = False # if we have a parent, pass our CPU usage stats on - if ( - self.parent_context is not None - and hasattr(self.parent_context, '_resource_usage') + if self.parent_context is not None and hasattr( + self.parent_context, "_resource_usage" ): self.parent_context._resource_usage += self._resource_usage @@ -320,9 +333,7 @@ class LoggingContext(object): # When we stop, let's record the cpu used since we started if not self.usage_start: - logger.warning( - "Called stop on logcontext %s without calling start", self, - ) + logger.warning("Called stop on logcontext %s without calling start", self) return utime_delta, stime_delta = self._get_cputime() @@ -407,6 +418,7 @@ class LoggingContextFilter(logging.Filter): **defaults: Default values to avoid formatters complaining about missing fields """ + def __init__(self, **defaults): self.defaults = defaults @@ -442,17 +454,12 @@ class PreserveLoggingContext(object): def __enter__(self): """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context( - self.new_context - ) + self.current_context = LoggingContext.set_current_context(self.new_context) if self.current_context: self.has_parent = self.current_context.previous_context is not None if not self.current_context.alive: - logger.debug( - "Entering dead context: %s", - self.current_context, - ) + logger.debug("Entering dead context: %s", self.current_context) def __exit__(self, type, value, traceback): """Restores the current logging context""" @@ -470,10 +477,7 @@ class PreserveLoggingContext(object): if self.current_context is not LoggingContext.sentinel: if not self.current_context.alive: - logger.debug( - "Restoring dead context: %s", - self.current_context, - ) + logger.debug("Restoring dead context: %s", self.current_context) def nested_logging_context(suffix, parent_context=None): @@ -500,15 +504,16 @@ def nested_logging_context(suffix, parent_context=None): if parent_context is None: parent_context = LoggingContext.current_context() return LoggingContext( - parent_context=parent_context, - request=parent_context.request + "-" + suffix, + parent_context=parent_context, request=parent_context.request + "-" + suffix ) def preserve_fn(f): """Function decorator which wraps the function with run_in_background""" + def g(*args, **kwargs): return run_in_background(f, *args, **kwargs) + return g @@ -528,7 +533,7 @@ def run_in_background(f, *args, **kwargs): current = LoggingContext.current_context() try: res = f(*args, **kwargs) - except: # noqa: E722 + except: # noqa: E722 # the assumption here is that the caller doesn't want to be disturbed # by synchronous exceptions, so let's turn them into Failures. return defer.fail() @@ -665,6 +670,4 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): with LoggingContext(parent_context=logcontext): return f(*args, **kwargs) - return make_deferred_yieldable( - threads.deferToThreadPool(reactor, threadpool, g) - ) + return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py index a46bc47ce3..fbf570c756 100644 --- a/synapse/util/logformatter.py +++ b/synapse/util/logformatter.py @@ -29,6 +29,7 @@ class LogFormatter(logging.Formatter): (Normally only stack frames between the point the exception was raised and where it was caught are logged). """ + def __init__(self, *args, **kwargs): super(LogFormatter, self).__init__(*args, **kwargs) @@ -40,7 +41,7 @@ class LogFormatter(logging.Formatter): # check that we actually have an f_back attribute to work around # https://twistedmatrix.com/trac/ticket/9305 - if tb and hasattr(tb.tb_frame, 'f_back'): + if tb and hasattr(tb.tb_frame, "f_back"): sio.write("Capture point (most recent call last):\n") traceback.print_stack(tb.tb_frame.f_back, None, sio) diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index ef31458226..7df0fa6087 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -44,7 +44,7 @@ def _log_debug_as_f(f, msg, msg_args): lineno=lineno, msg=msg, args=msg_args, - exc_info=None + exc_info=None, ) logger.handle(record) @@ -70,20 +70,11 @@ def log_function(f): r = r[:50] + "..." return r - func_args = [ - "%s=%s" % (k, format(v)) for k, v in bound_args.items() - ] + func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()] - msg_args = { - "func_name": func_name, - "args": ", ".join(func_args) - } + msg_args = {"func_name": func_name, "args": ", ".join(func_args)} - _log_debug_as_f( - f, - "Invoked '%(func_name)s' with args: %(args)s", - msg_args - ) + _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args) return f(*args, **kwargs) @@ -103,19 +94,13 @@ def time_function(f): start = time.clock() try: - _log_debug_as_f( - f, - "[FUNC START] {%s-%d}", - (func_name, id), - ) + _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id)) r = f(*args, **kwargs) finally: end = time.clock() _log_debug_as_f( - f, - "[FUNC END] {%s-%d} %.3f sec", - (func_name, id, end - start,), + f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start) ) return r @@ -137,9 +122,8 @@ def trace_function(f): s = inspect.currentframe().f_back to_print = [ - "\t%s:%s %s. Args: args=%s, kwargs=%s" % ( - pathname, linenum, func_name, args, kwargs - ) + "\t%s:%s %s. Args: args=%s, kwargs=%s" + % (pathname, linenum, func_name, args, kwargs) ] while s: if True or s.f_globals["__name__"].startswith("synapse"): @@ -147,9 +131,7 @@ def trace_function(f): args_string = inspect.formatargvalues(*inspect.getargvalues(s)) to_print.append( - "\t%s:%d %s. Args: %s" % ( - filename, lineno, function, args_string - ) + "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string) ) s = s.f_back @@ -163,7 +145,7 @@ def trace_function(f): lineno=lineno, msg=msg, args=None, - exc_info=None + exc_info=None, ) logger.handle(record) @@ -182,13 +164,13 @@ def get_previous_frames(): filename, lineno, function, _, _ = inspect.getframeinfo(s) args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - to_return.append("{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string - )) + to_return.append( + "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string) + ) s = s.f_back - return ", ". join(to_return) + return ", ".join(to_return) def get_previous_frame(ignore=[]): @@ -201,7 +183,10 @@ def get_previous_frame(ignore=[]): args_string = inspect.formatargvalues(*inspect.getargvalues(s)) return "{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string + filename, + lineno, + function, + args_string, ) s = s.f_back diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 628a2962d9..631654f297 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -74,27 +74,25 @@ def manhole(username, password, globals): twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` """ if not isinstance(password, bytes): - password = password.encode('ascii') + password = password.encode("ascii") - checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( - **{username: password} - ) + checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) rlm = manhole_ssh.TerminalRealm() rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( - SynapseManhole, - dict(globals, __name__="__console__") + SynapseManhole, dict(globals, __name__="__console__") ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY) - factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY) + factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY) + factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY) return factory class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" + def connectionMade(self): super(SynapseManhole, self).connectionMade() @@ -127,7 +125,7 @@ class SynapseManholeInterpreter(ManholeInterpreter): value = SyntaxError(msg, (filename, lineno, offset, line)) sys.last_value = value lines = traceback.format_exception_only(type, value) - self.write(''.join(lines)) + self.write("".join(lines)) def showtraceback(self): """Display the exception that just occurred. @@ -140,6 +138,6 @@ class SynapseManholeInterpreter(ManholeInterpreter): try: # We remove the first stack item because it is our own code. lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) - self.write(''.join(lines)) + self.write("".join(lines)) finally: last_tb = ei = None diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4b4ac5f6c7..01284d3cf8 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -30,25 +30,31 @@ block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"]) block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"]) block_ru_utime = Counter( - "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"] +) block_ru_stime = Counter( - "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"] +) block_db_txn_count = Counter( - "synapse_util_metrics_block_db_txn_count", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_count", "", ["block_name"] +) # seconds spent waiting for db txns, excluding scheduling time, in this block block_db_txn_duration = Counter( - "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"] +) # seconds spent waiting for a db connection, in this block block_db_sched_duration = Counter( - "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"] +) # Tracks the number of blocks currently active in_flight = InFlightGauge( - "synapse_util_metrics_block_in_flight", "", + "synapse_util_metrics_block_in_flight", + "", labels=["block_name"], sub_metrics=["real_time_max", "real_time_sum"], ) @@ -62,13 +68,18 @@ def measure_func(name): with Measure(self.clock, name): r = yield func(self, *args, **kwargs) defer.returnValue(r) + return measured_func + return wrapper class Measure(object): __slots__ = [ - "clock", "name", "start_context", "start", + "clock", + "name", + "start_context", + "start", "created_context", "start_usage", ] @@ -108,7 +119,9 @@ class Measure(object): if context != self.start_context: logger.warn( "Context has unexpectedly changed from '%s' to '%s'. (%r)", - self.start_context, context, self.name + self.start_context, + context, + self.name, ) return @@ -126,8 +139,7 @@ class Measure(object): block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: logger.warn( - "Failed to save metrics! OLD: %r, NEW: %r", - self.start_usage, current + "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current ) if self.created_context: diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 4288312b8a..522acd5aa8 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -28,15 +28,13 @@ def load_module(provider): """ # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) + module, clz = provider["module"].rsplit(".", 1) module = importlib.import_module(module) provider_class = getattr(module, clz) try: provider_config = provider_class.parse_config(provider["config"]) except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) + raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e)) return provider_class, provider_config diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py index a6c30e5265..c8bcbe297a 100644 --- a/synapse/util/msisdn.py +++ b/synapse/util/msisdn.py @@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number): phoneNumber = phonenumbers.parse(number, country) except phonenumbers.NumberParseException: raise SynapseError(400, "Unable to parse phone number") - return phonenumbers.format_number( - phoneNumber, phonenumbers.PhoneNumberFormat.E164 - )[1:] + return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[ + 1: + ] diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index b146d137f4..06defa8199 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -56,11 +56,7 @@ class FederationRateLimiter(object): _PerHostRatelimiter """ return self.ratelimiters.setdefault( - host, - _PerHostRatelimiter( - clock=self.clock, - config=self._config, - ) + host, _PerHostRatelimiter(clock=self.clock, config=self._config) ).ratelimit() @@ -112,8 +108,7 @@ class _PerHostRatelimiter(object): # remove any entries from request_times which aren't within the window self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size + r for r in self.request_times if time_now - r < self.window_size ] # reject the request if we already have too many queued up (either @@ -121,9 +116,7 @@ class _PerHostRatelimiter(object): queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) if queue_size > self.reject_limit: raise LimitExceededError( - retry_after_ms=int( - self.window_size / self.sleep_limit - ), + retry_after_ms=int(self.window_size / self.sleep_limit) ) self.request_times.append(time_now) @@ -143,22 +136,18 @@ class _PerHostRatelimiter(object): logger.debug( "Ratelimit [%s]: len(self.request_times)=%d", - id(request_id), len(self.request_times), + id(request_id), + len(self.request_times), ) if len(self.request_times) > self.sleep_limit: - logger.debug( - "Ratelimiter: sleeping request for %f sec", self.sleep_sec, - ) + logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec) ret_defer = run_in_background(self.clock.sleep, self.sleep_sec) self.sleeping_requests.add(request_id) def on_wait_finished(_): - logger.debug( - "Ratelimit [%s]: Finished sleeping", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) self.sleeping_requests.discard(request_id) queue_defer = queue_request() return queue_defer @@ -168,10 +157,7 @@ class _PerHostRatelimiter(object): ret_defer = queue_request() def on_start(r): - logger.debug( - "Ratelimit [%s]: Processing req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processing req", id(request_id)) self.current_processing.add(request_id) return r @@ -193,10 +179,7 @@ class _PerHostRatelimiter(object): return make_deferred_yieldable(ret_defer) def _on_exit(self, request_id): - logger.debug( - "Ratelimit [%s]: Processed req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processed req", id(request_id)) self.current_processing.discard(request_id) try: # start processing the next item on the queue. diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 69dffd8244..982c6d81ca 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -20,9 +20,7 @@ import six from six import PY2, PY3 from six.moves import range -_string_with_symbols = ( - string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" -) +_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure @@ -31,13 +29,11 @@ rand = random.SystemRandom() def random_string(length): - return ''.join(rand.choice(string.ascii_letters) for _ in range(length)) + return "".join(rand.choice(string.ascii_letters) for _ in range(length)) def random_string_with_symbols(length): - return ''.join( - rand.choice(_string_with_symbols) for _ in range(length) - ) + return "".join(rand.choice(_string_with_symbols) for _ in range(length)) def is_ascii(s): @@ -45,7 +41,7 @@ def is_ascii(s): if PY3: if isinstance(s, bytes): try: - s.decode('ascii').encode('ascii') + s.decode("ascii").encode("ascii") except UnicodeDecodeError: return False except UnicodeEncodeError: @@ -104,12 +100,12 @@ def exception_to_unicode(e): # and instead look at what is in the args member. if len(e.args) == 0: - return u"" + return "" elif len(e.args) > 1: return six.text_type(repr(e.args)) msg = e.args[0] if isinstance(msg, bytes): - return msg.decode('utf-8', errors='replace') + return msg.decode("utf-8", errors="replace") else: return msg diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 75efa0117b..3ec1dfb0c2 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -35,11 +35,13 @@ def check_3pid_allowed(hs, medium, address): for constraint in hs.config.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", - address, medium, constraint['pattern'], constraint['medium'], + address, + medium, + constraint["pattern"], + constraint["medium"], ) - if ( - medium == constraint['medium'] and - re.match(constraint['pattern'], address) + if medium == constraint["medium"] and re.match( + constraint["pattern"], address ): return True else: diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 3baba3225a..a4d9a462f7 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -23,44 +23,53 @@ logger = logging.getLogger(__name__) def get_version_string(module): try: - null = open(os.devnull, 'w') + null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) try: - git_branch = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_branch = ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_branch = "b=" + git_branch except subprocess.CalledProcessError: git_branch = "" try: - git_tag = subprocess.check_output( - ['git', 'describe', '--exact-match'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_tag = ( + subprocess.check_output( + ["git", "describe", "--exact-match"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_tag = "t=" + git_tag except subprocess.CalledProcessError: git_tag = "" try: - git_commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_commit = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) except subprocess.CalledProcessError: git_commit = "" try: dirty_string = "-this_is_a_dirty_checkout" - is_dirty = subprocess.check_output( - ['git', 'describe', '--dirty=' + dirty_string], - stderr=null, - cwd=cwd, - ).strip().decode('ascii').endswith(dirty_string) + is_dirty = ( + subprocess.check_output( + ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + .endswith(dirty_string) + ) git_dirty = "dirty" if is_dirty else "" except subprocess.CalledProcessError: @@ -68,16 +77,10 @@ def get_version_string(module): if git_branch or git_tag or git_commit or git_dirty: git_version = ",".join( - s for s in - (git_branch, git_tag, git_commit, git_dirty,) - if s + s for s in (git_branch, git_tag, git_commit, git_dirty) if s ) - return ( - "%s (%s)" % ( - module.__version__, git_version, - ) - ) + return "%s (%s)" % (module.__version__, git_version) except Exception as e: logger.info("Failed to check for git repository: %s", e) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 7a9e45aca9..9bf6a44f75 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -69,9 +69,7 @@ class WheelTimer(object): # Add empty entries between the end of the current list and when we want # to insert. This ensures there are no gaps. - self.entries.extend( - _Entry(key) for key in range(last_key, then_key + 1) - ) + self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1)) self.entries[-1].queue.append(obj) |