From 524b8ead778e51adfd6667a33f2700f8e071c256 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Fri, 10 Sep 2021 17:03:18 +0100 Subject: Add types to synapse.util. (#10601) --- synapse/util/__init__.py | 40 ++++++++++------- synapse/util/async_helpers.py | 16 ++++--- synapse/util/batching_queue.py | 2 +- synapse/util/caches/__init__.py | 14 +++--- synapse/util/caches/deferred_cache.py | 14 +++--- synapse/util/caches/dictionary_cache.py | 24 ++++++----- synapse/util/caches/lrucache.py | 5 ++- synapse/util/caches/stream_change_cache.py | 2 +- synapse/util/caches/treecache.py | 16 +++---- synapse/util/daemonize.py | 2 +- synapse/util/distributor.py | 23 +++++----- synapse/util/file_consumer.py | 48 +++++++++++++-------- synapse/util/frozenutils.py | 5 ++- synapse/util/httpresourcetree.py | 27 ++++++------ synapse/util/linked_list.py | 8 ++-- synapse/util/macaroons.py | 2 +- synapse/util/manhole.py | 52 ++++++++++++++-------- synapse/util/patch_inline_callbacks.py | 4 +- synapse/util/ratelimitutils.py | 57 +++++++++++++----------- synapse/util/retryutils.py | 69 ++++++++++++++++++------------ synapse/util/rlimit.py | 2 +- synapse/util/templates.py | 8 ++-- synapse/util/threepids.py | 12 ++++-- synapse/util/versionstring.py | 2 +- synapse/util/wheel_timer.py | 35 ++++++++------- 25 files changed, 281 insertions(+), 208 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b69f562ca5..bd234549bd 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,27 +15,35 @@ import json import logging import re -from typing import Pattern +import typing +from typing import Any, Callable, Dict, Generator, Pattern import attr from frozendict import frozendict from twisted.internet import defer, task +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IDelayedCall, IReactorTime +from twisted.internet.task import LoopingCall +from twisted.python.failure import Failure from synapse.logging import context +if typing.TYPE_CHECKING: + pass + logger = logging.getLogger(__name__) _WILDCARD_RUN = re.compile(r"([\?\*]+)") -def _reject_invalid_json(val): +def _reject_invalid_json(val: Any) -> None: """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) -def _handle_frozendict(obj): +def _handle_frozendict(obj: Any) -> Dict[Any, Any]: """Helper for json_encoder. Makes frozendicts serializable by returning the underlying dict """ @@ -60,10 +68,10 @@ json_encoder = json.JSONEncoder( json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) -def unwrapFirstError(failure): +def unwrapFirstError(failure: Failure) -> Failure: # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) - return failure.value.subFailure + return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations @attr.s(slots=True) @@ -75,25 +83,25 @@ class Clock: reactor: The Twisted reactor to use. """ - _reactor = attr.ib() + _reactor: IReactorTime = attr.ib() - @defer.inlineCallbacks - def sleep(self, seconds): - d = defer.Deferred() + @defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations + def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]": + d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) res = yield d return res - def time(self): + def time(self) -> float: """Returns the current system time in seconds since epoch.""" return self._reactor.seconds() - def time_msec(self): + def time_msec(self) -> int: """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) - def looping_call(self, f, msec, *args, **kwargs): + def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall: """Call a function repeatedly. Waits `msec` initially before calling `f` for the first time. @@ -102,8 +110,8 @@ class Clock: other than trivial, you probably want to wrap it in run_as_background_process. Args: - f(function): The function to call repeatedly. - msec(float): How long to wait between calls in milliseconds. + f: The function to call repeatedly. + msec: How long to wait between calls in milliseconds. *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. """ @@ -113,7 +121,7 @@ class Clock: d.addErrback(log_failure, "Looping call died", consumeErrors=False) return call - def call_later(self, delay, callback, *args, **kwargs): + def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall: """Call something later Note that the function will be called with no logcontext, so if it is anything @@ -133,7 +141,7 @@ class Clock: with context.PreserveLoggingContext(): return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) - def cancel_call_later(self, timer, ignore_errs=False): + def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None: try: timer.cancel() except Exception: diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a3b65aee27..82d918a05f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -37,6 +37,7 @@ import attr from typing_extensions import ContextManager from twisted.internet import defer +from twisted.internet.base import ReactorBase from twisted.internet.defer import CancelledError from twisted.internet.interfaces import IReactorTime from twisted.python import failure @@ -268,6 +269,7 @@ class Linearizer: if not clock: from twisted.internet import reactor + assert isinstance(reactor, ReactorBase) clock = Clock(reactor) self._clock = clock self.max_count = max_count @@ -411,7 +413,7 @@ class ReadWriteLock: # writers and readers have been resolved. The new writer replaces the latest # writer. - def __init__(self): + def __init__(self) -> None: # Latest readers queued self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {} @@ -503,7 +505,7 @@ def timeout_deferred( timed_out = [False] - def time_it_out(): + def time_it_out() -> None: timed_out[0] = True try: @@ -550,19 +552,21 @@ def timeout_deferred( return new_d +# This class can't be generic because it uses slots with attrs. +# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True, frozen=True) -class DoneAwaitable: +class DoneAwaitable: # should be: Generic[R] """Simple awaitable that returns the provided value.""" - value = attr.ib() + value = attr.ib(type=Any) # should be: R def __await__(self): return self - def __iter__(self): + def __iter__(self) -> "DoneAwaitable": return self - def __next__(self): + def __next__(self) -> None: raise StopIteration(self.value) diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 274cea7eb7..2a903004a9 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]): # First we create a defer and add it and the value to the list of # pending items. - d = defer.Deferred() + d: defer.Deferred[R] = defer.Deferred() self._next_values.setdefault(key, []).append((value, d)) # If we're not currently processing the key fire off a background diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 9012034b7a..cab1bf0c15 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -64,32 +64,32 @@ class CacheMetric: evicted_size = attr.ib(default=0) memory_usage = attr.ib(default=None) - def inc_hits(self): + def inc_hits(self) -> None: self.hits += 1 - def inc_misses(self): + def inc_misses(self) -> None: self.misses += 1 - def inc_evictions(self, size=1): + def inc_evictions(self, size: int = 1) -> None: self.evicted_size += size - def inc_memory_usage(self, memory: int): + def inc_memory_usage(self, memory: int) -> None: if self.memory_usage is None: self.memory_usage = 0 self.memory_usage += memory - def dec_memory_usage(self, memory: int): + def dec_memory_usage(self, memory: int) -> None: self.memory_usage -= memory - def clear_memory_usage(self): + def clear_memory_usage(self) -> None: if self.memory_usage is not None: self.memory_usage = 0 def describe(self): return [] - def collect(self): + def collect(self) -> None: try: if self._cache_type == "response_cache": response_cache_size.labels(self._cache_name).set(len(self._cache)) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index b6456392cd..f05590da0d 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]): TreeCache, "MutableMapping[KT, CacheEntry]" ] = cache_type() - def metrics_cb(): + def metrics_cb() -> None: cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) # cache is used for completed results and maps to the result itself, rather than @@ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]): def max_entries(self): return self.cache.max_size - def check_thread(self): + def check_thread(self) -> None: expected_thread = self.thread if expected_thread is None: self.thread = threading.current_thread() @@ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]): self._pending_deferred_cache[key] = entry - def compare_and_pop(): + def compare_and_pop() -> bool: """Check if our entry is still the one in _pending_deferred_cache, and if so, pop it. @@ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]): return False - def cb(result): + def cb(result) -> None: if compare_and_pop(): self.cache.set(key, result, entry.callbacks) else: @@ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]): # not have been. Either way, let's double-check now. entry.invalidate() - def eb(_fail): + def eb(_fail) -> None: compare_and_pop() entry.invalidate() @@ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]): for entry in iterate_tree_cache_entry(entry): entry.invalidate() - def invalidate_all(self): + def invalidate_all(self) -> None: self.check_thread() self.cache.clear() for entry in self._pending_deferred_cache.values(): @@ -332,7 +332,7 @@ class CacheEntry: self.callbacks = set(callbacks) self.invalidated = False - def invalidate(self): + def invalidate(self) -> None: if not self.invalidated: self.invalidated = True for callback in self.callbacks: diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 3f852edd7f..ade088aae2 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -27,10 +27,14 @@ logger = logging.getLogger(__name__) KT = TypeVar("KT") # The type of the dictionary keys. DKT = TypeVar("DKT") +# The type of the dictionary values. +DV = TypeVar("DV") +# This class can't be generic because it uses slots with attrs. +# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True) -class DictionaryEntry: +class DictionaryEntry: # should be: Generic[DKT, DV]. """Returned when getting an entry from the cache Attributes: @@ -43,10 +47,10 @@ class DictionaryEntry: """ full = attr.ib(type=bool) - known_absent = attr.ib() - value = attr.ib() + known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT] + value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV] - def __len__(self): + def __len__(self) -> int: return len(self.value) @@ -56,7 +60,7 @@ class _Sentinel(enum.Enum): sentinel = object() -class DictionaryCache(Generic[KT, DKT]): +class DictionaryCache(Generic[KT, DKT, DV]): """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ @@ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]): Args: key - dict_key: If given a set of keys then return only those keys + dict_keys: If given a set of keys then return only those keys that exist in the cache. Returns: @@ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]): self, sequence: int, key: KT, - value: Dict[DKT, Any], + value: Dict[DKT, DV], fetched_keys: Optional[Set[DKT]] = None, ) -> None: """Updates the entry in the cache @@ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]): self._update_or_insert(key, value, fetched_keys) def _update_or_insert( - self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT] + self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT] ) -> None: # We pop and reinsert as we need to tell the cache the size may have # changed - entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) + entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {})) entry.value.update(value) entry.known_absent.update(known_absent) self.cache[key] = entry - def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None: + def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None: self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 5c65d187b6..39dce9dd41 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -35,6 +35,7 @@ from typing import ( from typing_extensions import Literal from twisted.internet import reactor +from twisted.internet.interfaces import IReactorTime from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]): # Default `clock` to something sensible. Note that we rename it to # `real_clock` so that mypy doesn't think its still `Optional`. if clock is None: - real_clock = Clock(reactor) + real_clock = Clock(cast(IReactorTime, reactor)) else: real_clock = clock @@ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]): lock = threading.Lock() - def evict(): + def evict() -> None: while cache_len() > self.max_size: # Get the last node in the list (i.e. the oldest node). todelete = list_root.prev_node diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 3a41a8baa6..27b1da235e 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -195,7 +195,7 @@ class StreamChangeCache: for entity in r: del self._entity_to_key[entity] - def _evict(self): + def _evict(self) -> None: while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 4138931e7b..563845f867 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -35,17 +35,17 @@ class TreeCache: root = {key_1: {key_2: _value}} """ - def __init__(self): - self.size = 0 + def __init__(self) -> None: + self.size: int = 0 self.root = TreeCacheNode() - def __setitem__(self, key, value): - return self.set(key, value) + def __setitem__(self, key, value) -> None: + self.set(key, value) - def __contains__(self, key): + def __contains__(self, key) -> bool: return self.get(key, SENTINEL) is not SENTINEL - def set(self, key, value): + def set(self, key, value) -> None: if isinstance(value, TreeCacheNode): # this would mean we couldn't tell where our tree ended and the value # started. @@ -73,7 +73,7 @@ class TreeCache: return default return node.get(key[-1], default) - def clear(self): + def clear(self) -> None: self.size = 0 self.root = TreeCacheNode() @@ -128,7 +128,7 @@ class TreeCache: def values(self): return iterate_tree_cache_entry(self.root) - def __len__(self): + def __len__(self) -> int: return self.size diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index d8532411c2..f1a351cfd4 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - signal.signal(signal.SIGTERM, sigterm) # Cleanup pid file at exit. - def exit(): + def exit() -> None: logger.warning("Stopping daemon.") os.remove(pid_file) sys.exit(0) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 1f803aef6d..31097d6439 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Any, Callable, Dict, List from twisted.internet import defer @@ -37,11 +38,11 @@ class Distributor: model will do for today. """ - def __init__(self): - self.signals = {} - self.pre_registration = {} + def __init__(self) -> None: + self.signals: Dict[str, Signal] = {} + self.pre_registration: Dict[str, List[Callable]] = {} - def declare(self, name): + def declare(self, name: str) -> None: if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) @@ -52,7 +53,7 @@ class Distributor: for observer in self.pre_registration[name]: signal.observe(observer) - def observe(self, name, observer): + def observe(self, name: str, observer: Callable) -> None: if name in self.signals: self.signals[name].observe(observer) else: @@ -62,7 +63,7 @@ class Distributor: self.pre_registration[name] = [] self.pre_registration[name].append(observer) - def fire(self, name, *args, **kwargs): + def fire(self, name: str, *args, **kwargs) -> None: """Dispatches the given signal to the registered observers. Runs the observers as a background process. Does not return a deferred. @@ -83,18 +84,18 @@ class Signal: method into all of the observers. """ - def __init__(self, name): - self.name = name - self.observers = [] + def __init__(self, name: str): + self.name: str = name + self.observers: List[Callable] = [] - def observe(self, observer): + def observe(self, observer: Callable) -> None: """Adds a new callable to the observer list which will be invoked by the 'fire' method. Each observer callable may return a Deferred.""" self.observers.append(observer) - def fire(self, *args, **kwargs): + def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]": """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is not an error to fire a signal with no observers. diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index e946189f9a..de2adacd70 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,10 +13,14 @@ # limitations under the License. import queue +from typing import BinaryIO, Optional, Union, cast from twisted.internet import threads +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IPullProducer, IPushProducer from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import ISynapseReactor class BackgroundFileConsumer: @@ -24,9 +28,9 @@ class BackgroundFileConsumer: and pull producers Args: - file_obj (file): The file like object to write to. Closed when + file_obj: The file like object to write to. Closed when finished. - reactor (twisted.internet.reactor): the Twisted reactor to use + reactor: the Twisted reactor to use """ # For PushProducers pause if we have this many unwritten slices @@ -34,13 +38,13 @@ class BackgroundFileConsumer: # And resume once the size of the queue is less than this _RESUME_ON_QUEUE_SIZE = 2 - def __init__(self, file_obj, reactor): - self._file_obj = file_obj + def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None: + self._file_obj: BinaryIO = file_obj - self._reactor = reactor + self._reactor: ISynapseReactor = reactor # Producer we're registered with - self._producer = None + self._producer: Optional[Union[IPushProducer, IPullProducer]] = None # True if PushProducer, false if PullProducer self.streaming = False @@ -51,20 +55,22 @@ class BackgroundFileConsumer: # Queue of slices of bytes to be written. When producer calls # unregister a final None is sent. - self._bytes_queue = queue.Queue() + self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() # Deferred that is resolved when finished writing - self._finished_deferred = None + self._finished_deferred: Optional[Deferred[None]] = None # If the _writer thread throws an exception it gets stored here. - self._write_exception = None + self._write_exception: Optional[Exception] = None - def registerProducer(self, producer, streaming): + def registerProducer( + self, producer: Union[IPushProducer, IPullProducer], streaming: bool + ) -> None: """Part of IConsumer interface Args: - producer (IProducer) - streaming (bool): True if push based producer, False if pull + producer + streaming: True if push based producer, False if pull based. """ if self._producer: @@ -81,29 +87,33 @@ class BackgroundFileConsumer: if not streaming: self._producer.resumeProducing() - def unregisterProducer(self): + def unregisterProducer(self) -> None: """Part of IProducer interface""" self._producer = None + assert self._finished_deferred is not None if not self._finished_deferred.called: self._bytes_queue.put_nowait(None) - def write(self, bytes): + def write(self, write_bytes: bytes) -> None: """Part of IProducer interface""" if self._write_exception: raise self._write_exception + assert self._finished_deferred is not None if self._finished_deferred.called: raise Exception("consumer has closed") - self._bytes_queue.put_nowait(bytes) + self._bytes_queue.put_nowait(write_bytes) # If this is a PushProducer and the queue is getting behind # then we pause the producer. if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: self._paused_producer = True - self._producer.pauseProducing() + assert self._producer is not None + # cast safe because `streaming` means this is an IPushProducer + cast(IPushProducer, self._producer).pauseProducing() - def _writer(self): + def _writer(self) -> None: """This is run in a background thread to write to the file.""" try: while self._producer or not self._bytes_queue.empty(): @@ -130,11 +140,11 @@ class BackgroundFileConsumer: finally: self._file_obj.close() - def wait(self): + def wait(self) -> "Deferred[None]": """Returns a deferred that resolves when finished writing to file""" return make_deferred_yieldable(self._finished_deferred) - def _resume_paused_producer(self): + def _resume_paused_producer(self) -> None: """Gets called if we should resume producing after being paused""" if self._paused_producer and self._producer: self._paused_producer = False diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 2ac7c2913c..9c405eb4d7 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from frozendict import frozendict -def freeze(o): +def freeze(o: Any) -> Any: if isinstance(o, dict): return frozendict({k: freeze(v) for k, v in o.items()}) @@ -33,7 +34,7 @@ def freeze(o): return o -def unfreeze(o): +def unfreeze(o: Any) -> Any: if isinstance(o, (dict, frozendict)): return {k: unfreeze(v) for k, v in o.items()} diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 3c0e8469f3..b163643ca3 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -13,42 +13,43 @@ # limitations under the License. import logging +from typing import Dict -from twisted.web.resource import NoResource +from twisted.web.resource import NoResource, Resource logger = logging.getLogger(__name__) -def create_resource_tree(desired_tree, root_resource): +def create_resource_tree( + desired_tree: Dict[str, Resource], root_resource: Resource +) -> Resource: """Create the resource tree for this homeserver. This in unduly complicated because Twisted does not support putting child resources more than 1 level deep at a time. Args: - web_client (bool): True to enable the web client. - root_resource (twisted.web.resource.Resource): The root - resource to add the tree to. + desired_tree: Dict from desired paths to desired resources. + root_resource: The root resource to add the tree to. Returns: - twisted.web.resource.Resource: the ``root_resource`` with a tree of - child resources added to it. + The ``root_resource`` with a tree of child resources added to it. """ # ideally we'd just use getChild and putChild but getChild doesn't work # unless you give it a Request object IN ADDITION to the name :/ So # instead, we'll store a copy of this mapping so we can actually add # extra resources to existing nodes. See self._resource_id for the key. - resource_mappings = {} - for full_path, res in desired_tree.items(): + resource_mappings: Dict[str, Resource] = {} + for full_path_str, res in desired_tree.items(): # twisted requires all resources to be bytes - full_path = full_path.encode("utf-8") + full_path = full_path_str.encode("utf-8") logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource = NoResource() + child_resource: Resource = NoResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource @@ -83,7 +84,7 @@ def create_resource_tree(desired_tree, root_resource): return root_resource -def _resource_id(resource, path_seg): +def _resource_id(resource: Resource, path_seg: bytes) -> str: """Construct an arbitrary resource ID so you can retrieve the mapping later. @@ -96,4 +97,4 @@ def _resource_id(resource, path_seg): Returns: str: A unique string which can be a key to the child Resource. """ - return "%s-%s" % (resource, path_seg) + return "%s-%r" % (resource, path_seg) diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py index a456b136f0..9f4be757ba 100644 --- a/synapse/util/linked_list.py +++ b/synapse/util/linked_list.py @@ -74,7 +74,7 @@ class ListNode(Generic[P]): new_node._refs_insert_after(node) return new_node - def remove_from_list(self): + def remove_from_list(self) -> None: """Remove this node from the list.""" with self._LOCK: self._refs_remove_node_from_list() @@ -84,7 +84,7 @@ class ListNode(Generic[P]): # immediately rather than at the next GC. self.cache_entry = None - def move_after(self, node: "ListNode"): + def move_after(self, node: "ListNode") -> None: """Move this node from its current location in the list to after the given node. """ @@ -103,7 +103,7 @@ class ListNode(Generic[P]): # Insert self back into the list, after target node self._refs_insert_after(node) - def _refs_remove_node_from_list(self): + def _refs_remove_node_from_list(self) -> None: """Internal method to *just* remove the node from the list, without e.g. clearing out the cache entry. """ @@ -122,7 +122,7 @@ class ListNode(Generic[P]): self.prev_node = None self.next_node = None - def _refs_insert_after(self, node: "ListNode"): + def _refs_insert_after(self, node: "ListNode") -> None: """Internal method to insert the node after the given node.""" # This method should only be called when we're not already in the list. diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index d1f76e3dc5..84e4f6ff55 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -77,7 +77,7 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N should be considered expired. Normally the current time. """ - def verify_expiry_caveat(caveat: str): + def verify_expiry_caveat(caveat: str) -> bool: time_msec = get_time_ms() prefix = "time < " if not caveat.startswith(prefix): diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index cfb5b94ca9..f8b2d7bea9 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -15,6 +15,7 @@ import inspect import sys import traceback +from typing import Any, Dict, Optional from twisted.conch import manhole_ssh from twisted.conch.insults import insults @@ -22,6 +23,9 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal from twisted.internet import defer +from twisted.internet.protocol import Factory + +from synapse.config.server import ManholeConfig PUBLIC_KEY = ( "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" @@ -61,22 +65,22 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= -----END RSA PRIVATE KEY-----""" -def manhole(settings, globals): +def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: """Starts a ssh listener with password authentication using the given username and password. Clients connecting to the ssh listener will find themselves in a colored python shell with the supplied globals. Args: - username(str): The username ssh clients should auth with. - password(str): The password ssh clients should auth with. - globals(dict): The variables to expose in the shell. + username: The username ssh clients should auth with. + password: The password ssh clients should auth with. + globals: The variables to expose in the shell. Returns: - twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` + A factory to pass to ``listenTCP`` """ username = settings.username - password = settings.password + password = settings.password.encode("ascii") priv_key = settings.priv_key if priv_key is None: priv_key = Key.fromString(PRIVATE_KEY) @@ -84,19 +88,22 @@ def manhole(settings, globals): if pub_key is None: pub_key = Key.fromString(PUBLIC_KEY) - if not isinstance(password, bytes): - password = password.encode("ascii") - checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) rlm = manhole_ssh.TerminalRealm() - rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( + # mypy ignored here because: + # - can't deduce types of lambdas + # - variable is Type[ServerProtocol], expr is Callable[[], ServerProtocol] + rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( # type: ignore[misc,assignment] SynapseManhole, dict(globals, __name__="__console__") ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.privateKeys[b"ssh-rsa"] = priv_key - factory.publicKeys[b"ssh-rsa"] = pub_key + + # conch has the wrong type on these dicts (says bytes to bytes, + # should be bytes to Keys judging by how it's used). + factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment] + factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment] return factory @@ -104,7 +111,7 @@ def manhole(settings, globals): class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" - def connectionMade(self): + def connectionMade(self) -> None: super().connectionMade() # replace the manhole interpreter with our own impl @@ -114,13 +121,14 @@ class SynapseManhole(ColoredManhole): class SynapseManholeInterpreter(ManholeInterpreter): - def showsyntaxerror(self, filename=None): + def showsyntaxerror(self, filename: Optional[str] = None) -> None: """Display the syntax error that just occurred. Overrides the base implementation, ignoring sys.excepthook. We always want any syntax errors to be sent to the terminal, rather than sentry. """ type, value, tb = sys.exc_info() + assert value is not None sys.last_type = type sys.last_value = value sys.last_traceback = tb @@ -138,7 +146,7 @@ class SynapseManholeInterpreter(ManholeInterpreter): lines = traceback.format_exception_only(type, value) self.write("".join(lines)) - def showtraceback(self): + def showtraceback(self) -> None: """Display the exception that just occurred. Overrides the base implementation, ignoring sys.excepthook. We always want @@ -146,14 +154,22 @@ class SynapseManholeInterpreter(ManholeInterpreter): """ sys.last_type, sys.last_value, last_tb = ei = sys.exc_info() sys.last_traceback = last_tb + assert last_tb is not None + try: # We remove the first stack item because it is our own code. lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) self.write("".join(lines)) finally: - last_tb = ei = None - - def displayhook(self, obj): + # On the line below, last_tb and ei appear to be dead. + # It's unclear whether there is a reason behind this line. + # It conceivably could be because an exception raised in this block + # will keep the local frame (containing these local variables) around. + # This was adapted taken from CPython's Lib/code.py; see here: + # https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150 + last_tb = ei = None # type: ignore + + def displayhook(self, obj: Any) -> None: """ We override the displayhook so that we automatically convert coroutines into Deferreds. (Our superclass' displayhook will take care of the rest, diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 99f01e325c..9dd010af3b 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -24,7 +24,7 @@ from twisted.python.failure import Failure _already_patched = False -def do_patch(): +def do_patch() -> None: """ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit """ @@ -107,7 +107,7 @@ def do_patch(): _already_patched = True -def _check_yield_points(f: Callable, changes: List[str]): +def _check_yield_points(f: Callable, changes: List[str]) -> Callable: """Wraps a generator that is about to be passed to defer.inlineCallbacks checking that after every yield the log contexts are correct. diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index a654c69684..dfe628c97e 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -15,33 +15,36 @@ import collections import contextlib import logging +import typing +from typing import Any, DefaultDict, Iterator, List, Set from twisted.internet import defer from synapse.api.errors import LimitExceededError +from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background, ) +from synapse.util import Clock + +if typing.TYPE_CHECKING: + from contextlib import _GeneratorContextManager logger = logging.getLogger(__name__) class FederationRateLimiter: - def __init__(self, clock, config): - """ - Args: - clock (Clock) - config (FederationRateLimitConfig) - """ - - def new_limiter(): + def __init__(self, clock: Clock, config: FederationRateLimitConfig): + def new_limiter() -> "_PerHostRatelimiter": return _PerHostRatelimiter(clock=clock, config=config) - self.ratelimiters = collections.defaultdict(new_limiter) + self.ratelimiters: DefaultDict[ + str, "_PerHostRatelimiter" + ] = collections.defaultdict(new_limiter) - def ratelimit(self, host): + def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]": """Used to ratelimit an incoming request from a given host Example usage: @@ -60,11 +63,11 @@ class FederationRateLimiter: class _PerHostRatelimiter: - def __init__(self, clock, config): + def __init__(self, clock: Clock, config: FederationRateLimitConfig): """ Args: - clock (Clock) - config (FederationRateLimitConfig) + clock + config """ self.clock = clock @@ -75,21 +78,23 @@ class _PerHostRatelimiter: self.concurrent_requests = config.concurrent # request_id objects for requests which have been slept - self.sleeping_requests = set() + self.sleeping_requests: Set[object] = set() # map from request_id object to Deferred for requests which are ready # for processing but have been queued - self.ready_request_queue = collections.OrderedDict() + self.ready_request_queue: collections.OrderedDict[ + object, defer.Deferred[None] + ] = collections.OrderedDict() # request id objects for requests which are in progress - self.current_processing = set() + self.current_processing: Set[object] = set() # times at which we have recently (within the last window_size ms) # received requests. - self.request_times = [] + self.request_times: List[int] = [] @contextlib.contextmanager - def ratelimit(self): + def ratelimit(self) -> "Iterator[defer.Deferred[None]]": # `contextlib.contextmanager` takes a generator and turns it into a # context manager. The generator should only yield once with a value # to be returned by manager. @@ -102,7 +107,7 @@ class _PerHostRatelimiter: finally: self._on_exit(request_id) - def _on_enter(self, request_id): + def _on_enter(self, request_id: object) -> "defer.Deferred[None]": time_now = self.clock.time_msec() # remove any entries from request_times which aren't within the window @@ -120,9 +125,9 @@ class _PerHostRatelimiter: self.request_times.append(time_now) - def queue_request(): + def queue_request() -> "defer.Deferred[None]": if len(self.current_processing) >= self.concurrent_requests: - queue_defer = defer.Deferred() + queue_defer: defer.Deferred[None] = defer.Deferred() self.ready_request_queue[request_id] = queue_defer logger.info( "Ratelimiter: queueing request (queue now %i items)", @@ -145,7 +150,7 @@ class _PerHostRatelimiter: self.sleeping_requests.add(request_id) - def on_wait_finished(_): + def on_wait_finished(_: Any) -> "defer.Deferred[None]": logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) self.sleeping_requests.discard(request_id) queue_defer = queue_request() @@ -155,19 +160,19 @@ class _PerHostRatelimiter: else: ret_defer = queue_request() - def on_start(r): + def on_start(r: object) -> object: logger.debug("Ratelimit [%s]: Processing req", id(request_id)) self.current_processing.add(request_id) return r - def on_err(r): + def on_err(r: object) -> object: # XXX: why is this necessary? this is called before we start # processing the request so why would the request be in # current_processing? self.current_processing.discard(request_id) return r - def on_both(r): + def on_both(r: object) -> object: # Ensure that we've properly cleaned up. self.sleeping_requests.discard(request_id) self.ready_request_queue.pop(request_id, None) @@ -177,7 +182,7 @@ class _PerHostRatelimiter: ret_defer.addBoth(on_both) return make_deferred_yieldable(ret_defer) - def _on_exit(self, request_id): + def _on_exit(self, request_id: object) -> None: logger.debug("Ratelimit [%s]: Processed req", id(request_id)) self.current_processing.discard(request_id) try: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 129b47cd49..648d9a95a7 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -13,9 +13,13 @@ # limitations under the License. import logging import random +from types import TracebackType +from typing import Any, Optional, Type import synapse.logging.context from synapse.api.errors import CodeMessageException +from synapse.storage import DataStore +from synapse.util import Clock logger = logging.getLogger(__name__) @@ -30,17 +34,17 @@ MAX_RETRY_INTERVAL = 2 ** 62 class NotRetryingDestination(Exception): - def __init__(self, retry_last_ts, retry_interval, destination): + def __init__(self, retry_last_ts: int, retry_interval: int, destination: str): """Raised by the limiter (and federation client) to indicate that we are are deliberately not attempting to contact a given server. Args: - retry_last_ts (int): the unix ts in milliseconds of our last attempt + retry_last_ts: the unix ts in milliseconds of our last attempt to contact the server. 0 indicates that the last attempt was successful or that we've never actually attempted to connect. - retry_interval (int): the time in milliseconds to wait until the next + retry_interval: the time in milliseconds to wait until the next attempt. - destination (str): the domain in question + destination: the domain in question """ msg = "Not retrying server %s." % (destination,) @@ -51,7 +55,13 @@ class NotRetryingDestination(Exception): self.destination = destination -async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): +async def get_retry_limiter( + destination: str, + clock: Clock, + store: DataStore, + ignore_backoff: bool = False, + **kwargs: Any, +) -> "RetryDestinationLimiter": """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. If we are not ready to retry the destination, this will raise a @@ -60,10 +70,10 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k CodeMessageException with code < 500) Args: - destination (str): name of homeserver - clock (synapse.util.clock): timing source - store (synapse.storage.transactions.TransactionStore): datastore - ignore_backoff (bool): true to ignore the historical backoff data and + destination: name of homeserver + clock: timing source + store: datastore + ignore_backoff: true to ignore the historical backoff data and try the request anyway. We will still reset the retry_interval on success. Example usage: @@ -114,13 +124,13 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k class RetryDestinationLimiter: def __init__( self, - destination, - clock, - store, - failure_ts, - retry_interval, - backoff_on_404=False, - backoff_on_failure=True, + destination: str, + clock: Clock, + store: DataStore, + failure_ts: Optional[int], + retry_interval: int, + backoff_on_404: bool = False, + backoff_on_failure: bool = True, ): """Marks the destination as "down" if an exception is thrown in the context, except for CodeMessageException with code < 500. @@ -128,17 +138,17 @@ class RetryDestinationLimiter: If no exception is raised, marks the destination as "up". Args: - destination (str) - clock (Clock) - store (DataStore) - failure_ts (int|None): when this destination started failing (in ms since + destination + clock + store + failure_ts: when this destination started failing (in ms since the epoch), or zero if the last request was successful - retry_interval (int): The next retry interval taken from the + retry_interval: The next retry interval taken from the database in milliseconds, or zero if the last request was successful. - backoff_on_404 (bool): Back off if we get a 404 + backoff_on_404: Back off if we get a 404 - backoff_on_failure (bool): set to False if we should not increase the + backoff_on_failure: set to False if we should not increase the retry interval on a failure. """ self.clock = clock @@ -150,10 +160,15 @@ class RetryDestinationLimiter: self.backoff_on_404 = backoff_on_404 self.backoff_on_failure = backoff_on_failure - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: valid_err_code = False if exc_type is None: valid_err_code = True @@ -161,7 +176,7 @@ class RetryDestinationLimiter: # avoid treating exceptions which don't derive from Exception as # failures; this is mostly so as not to catch defer._DefGen. valid_err_code = True - elif issubclass(exc_type, CodeMessageException): + elif isinstance(exc_val, CodeMessageException): # Some error codes are perfectly fine for some APIs, whereas other # APIs may expect to never received e.g. a 404. It's important to # handle 404 as some remote servers will return a 404 when the HS @@ -216,7 +231,7 @@ class RetryDestinationLimiter: if self.failure_ts is None: self.failure_ts = retry_last_ts - async def store_retry_timings(): + async def store_retry_timings() -> None: try: await self.store.set_destination_retry_timings( self.destination, diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index bf812ab516..06651e956d 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -18,7 +18,7 @@ import resource logger = logging.getLogger("synapse.app.homeserver") -def change_resource_limit(soft_file_no): +def change_resource_limit(soft_file_no: int) -> None: try: soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 38543dd1ea..eb3c8c9370 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -16,7 +16,7 @@ import time import urllib.parse -from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union import jinja2 @@ -25,9 +25,9 @@ if TYPE_CHECKING: def build_jinja_env( - template_search_directories: Iterable[str], + template_search_directories: Sequence[str], config: "HomeServerConfig", - autoescape: Union[bool, Callable[[str], bool], None] = None, + autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None, ) -> jinja2.Environment: """Set up a Jinja2 environment to load templates from the given search path @@ -110,5 +110,5 @@ def _create_mxc_to_http_filter( return mxc_to_http_filter -def _format_ts_filter(value: int, format: str): +def _format_ts_filter(value: int, format: str) -> str: return time.strftime(format, time.localtime(value / 1000)) diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index a1cf1960b0..baa9190a9a 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -14,6 +14,10 @@ import logging import re +import typing + +if typing.TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -28,13 +32,13 @@ logger = logging.getLogger(__name__) MAX_EMAIL_ADDRESS_LENGTH = 500 -def check_3pid_allowed(hs, medium, address): +def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: """Checks whether a given format of 3PID is allowed to be used on this HS Args: - hs (synapse.server.HomeServer): server - medium (str): 3pid medium - e.g. email, msisdn - address (str): address within that medium (e.g. "wotan@matrix.org") + hs: server + medium: 3pid medium - e.g. email, msisdn + address: address within that medium (e.g. "wotan@matrix.org") msisdns need to first have been canonicalised Returns: bool: whether the 3PID medium/address is allowed to be added to this HS diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index cb08af7385..1c20b24bbe 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -19,7 +19,7 @@ import subprocess logger = logging.getLogger(__name__) -def get_version_string(module): +def get_version_string(module) -> str: """Given a module calculate a git-aware version string for it. If called on a module not in a git checkout will return `__verison__`. diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 61814aff24..e108adc460 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -11,38 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Generic, List, TypeVar +T = TypeVar("T") -class _Entry: + +class _Entry(Generic[T]): __slots__ = ["end_key", "queue"] - def __init__(self, end_key): - self.end_key = end_key - self.queue = [] + def __init__(self, end_key: int) -> None: + self.end_key: int = end_key + self.queue: List[T] = [] -class WheelTimer: +class WheelTimer(Generic[T]): """Stores arbitrary objects that will be returned after their timers have expired. """ - def __init__(self, bucket_size=5000): + def __init__(self, bucket_size: int = 5000) -> None: """ Args: - bucket_size (int): Size of buckets in ms. Corresponds roughly to the + bucket_size: Size of buckets in ms. Corresponds roughly to the accuracy of the timer. """ - self.bucket_size = bucket_size - self.entries = [] - self.current_tick = 0 + self.bucket_size: int = bucket_size + self.entries: List[_Entry[T]] = [] + self.current_tick: int = 0 - def insert(self, now, obj, then): + def insert(self, now: int, obj: T, then: int) -> None: """Inserts object into timer. Args: - now (int): Current time in msec - obj (object): Object to be inserted - then (int): When to return the object strictly after. + now: Current time in msec + obj: Object to be inserted + then: When to return the object strictly after. """ then_key = int(then / self.bucket_size) + 1 @@ -70,7 +73,7 @@ class WheelTimer: self.entries[-1].queue.append(obj) - def fetch(self, now): + def fetch(self, now: int) -> List[T]: """Fetch any objects that have timed out Args: @@ -87,5 +90,5 @@ class WheelTimer: return ret - def __len__(self): + def __len__(self) -> int: return sum(len(entry.queue) for entry in self.entries) -- cgit 1.4.1