diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/async_helpers.py | 15 | ||||
-rw-r--r-- | synapse/util/caches/__init__.py | 6 | ||||
-rw-r--r-- | synapse/util/caches/cached_call.py | 129 | ||||
-rw-r--r-- | synapse/util/caches/descriptors.py | 17 | ||||
-rw-r--r-- | synapse/util/caches/stream_change_cache.py | 6 | ||||
-rw-r--r-- | synapse/util/distributor.py | 5 | ||||
-rw-r--r-- | synapse/util/file_consumer.py | 15 | ||||
-rw-r--r-- | synapse/util/iterutils.py | 3 | ||||
-rw-r--r-- | synapse/util/jsonobject.py | 6 | ||||
-rw-r--r-- | synapse/util/metrics.py | 3 | ||||
-rw-r--r-- | synapse/util/module_loader.py | 5 | ||||
-rw-r--r-- | synapse/util/patch_inline_callbacks.py | 17 | ||||
-rw-r--r-- | synapse/util/stringutils.py | 33 | ||||
-rw-r--r-- | synapse/util/templates.py | 115 |
14 files changed, 310 insertions, 65 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 9a873c8e8e..719e35b78d 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -252,8 +252,7 @@ class Linearizer: self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] def is_queued(self, key: Hashable) -> bool: - """Checks whether there is a process queued up waiting - """ + """Checks whether there is a process queued up waiting""" entry = self.key_to_defer.get(key) if not entry: # No entry so nothing is waiting. @@ -452,7 +451,9 @@ R = TypeVar("R") def timeout_deferred( - deferred: defer.Deferred, timeout: float, reactor: IReactorTime, + deferred: defer.Deferred, + timeout: float, + reactor: IReactorTime, ) -> defer.Deferred: """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new @@ -497,7 +498,7 @@ def timeout_deferred( delayed_call = reactor.callLater(timeout, time_it_out) def convert_cancelled(value: failure.Failure): - # if the orgininal deferred was cancelled, and our timeout has fired, then + # if the original deferred was cancelled, and our timeout has fired, then # the reason it was cancelled was due to our timeout. Turn the CancelledError # into a TimeoutError. if timed_out[0] and value.check(CancelledError): @@ -529,8 +530,7 @@ def timeout_deferred( @attr.s(slots=True, frozen=True) class DoneAwaitable: - """Simple awaitable that returns the provided value. - """ + """Simple awaitable that returns the provided value.""" value = attr.ib() @@ -545,8 +545,7 @@ class DoneAwaitable: def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: - """Convert a value to an awaitable if not already an awaitable. - """ + """Convert a value to an awaitable if not already an awaitable.""" if inspect.isawaitable(value): assert isinstance(value, Awaitable) return value diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 89f0b38535..e676c2cac4 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -149,8 +149,7 @@ KNOWN_KEYS = { def intern_string(string): - """Takes a (potentially) unicode string and interns it if it's ascii - """ + """Takes a (potentially) unicode string and interns it if it's ascii""" if string is None: return None @@ -161,8 +160,7 @@ def intern_string(string): def intern_dict(dictionary): - """Takes a dictionary and interns well known keys and their values - """ + """Takes a dictionary and interns well known keys and their values""" return { KNOWN_KEYS.get(key, key): _intern_known_values(key, value) for key, value in dictionary.items() diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py new file mode 100644 index 0000000000..3ee0f2317a --- /dev/null +++ b/synapse/util/caches/cached_call.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union + +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure + +from synapse.logging.context import make_deferred_yieldable, run_in_background + +TV = TypeVar("TV") + + +class CachedCall(Generic[TV]): + """A wrapper for asynchronous calls whose results should be shared + + This is useful for wrapping asynchronous functions, where there might be multiple + callers, but we only want to call the underlying function once (and have the result + returned to all callers). + + Similar results can be achieved via a lock of some form, but that typically requires + more boilerplate (and ends up being less efficient). + + Correctly handles Synapse logcontexts (logs and resource usage for the underlying + function are logged against the logcontext which is active when get() is first + called). + + Example usage: + + _cached_val = CachedCall(_load_prop) + + async def handle_request() -> X: + # We can call this multiple times, but it will result in a single call to + # _load_prop(). + return await _cached_val.get() + + async def _load_prop() -> X: + await difficult_operation() + + + The implementation is deliberately single-shot (ie, once the call is initiated, + there is no way to ask for it to be run). This keeps the implementation and + semantics simple. If you want to make a new call, simply replace the whole + CachedCall object. + """ + + __slots__ = ["_callable", "_deferred", "_result"] + + def __init__(self, f: Callable[[], Awaitable[TV]]): + """ + Args: + f: The underlying function. Only one call to this function will be alive + at once (per instance of CachedCall) + """ + self._callable = f # type: Optional[Callable[[], Awaitable[TV]]] + self._deferred = None # type: Optional[Deferred] + self._result = None # type: Union[None, Failure, TV] + + async def get(self) -> TV: + """Kick off the call if necessary, and return the result""" + + # Fire off the callable now if this is our first time + if not self._deferred: + self._deferred = run_in_background(self._callable) + + # we will never need the callable again, so make sure it can be GCed + self._callable = None + + # once the deferred completes, store the result. We cannot simply leave the + # result in the deferred, since if it's a Failure, GCing the deferred + # would then log a critical error about unhandled Failures. + def got_result(r): + self._result = r + + self._deferred.addBoth(got_result) + + # TODO: consider cancellation semantics. Currently, if the call to get() + # is cancelled, the underlying call will continue (and any future calls + # will get the result/exception), which I think is *probably* ok, modulo + # the fact the underlying call may be logged to a cancelled logcontext, + # and any eventual exception may not be reported. + + # we can now await the deferred, and once it completes, return the result. + await make_deferred_yieldable(self._deferred) + + # I *think* this is the easiest way to correctly raise a Failure without having + # to gut-wrench into the implementation of Deferred. + d = Deferred() + d.callback(self._result) + return await d + + +class RetryOnExceptionCachedCall(Generic[TV]): + """A wrapper around CachedCall which will retry the call if an exception is thrown + + This is used in much the same way as CachedCall, but adds some extra functionality + so that if the underlying function throws an exception, then the next call to get() + will initiate another call to the underlying function. (Any calls to get() which + are already pending will raise the exception.) + """ + + slots = ["_cachedcall"] + + def __init__(self, f: Callable[[], Awaitable[TV]]): + async def _wrapper() -> TV: + try: + return await f() + except Exception: + # the call raised an exception: replace the underlying CachedCall to + # trigger another call next time get() is called + self._cachedcall = CachedCall(_wrapper) + raise + + self._cachedcall = CachedCall(_wrapper) + + async def get(self) -> TV: + return await self._cachedcall.get() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index a924140cdf..4e84379914 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -122,7 +122,8 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, cache_context: bool = False, + max_entries: int = 1000, + cache_context: bool = False, ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. @@ -156,7 +157,9 @@ def lru_cache( def func(orig: F) -> _LruCachedFunction[F]: desc = LruCacheDescriptor( - orig, max_entries=max_entries, cache_context=cache_context, + orig, + max_entries=max_entries, + cache_context=cache_context, ) return cast(_LruCachedFunction[F], desc) @@ -170,14 +173,18 @@ class LruCacheDescriptor(_CacheDescriptorBase): sentinel = object() def __init__( - self, orig, max_entries: int = 1000, cache_context: bool = False, + self, + orig, + max_entries: int = 1000, + cache_context: bool = False, ): super().__init__(orig, num_args=None, cache_context=cache_context) self.max_entries = max_entries def __get__(self, obj, owner): cache = LruCache( - cache_name=self.orig.__name__, max_size=self.max_entries, + cache_name=self.orig.__name__, + max_size=self.max_entries, ) # type: LruCache[CacheKey, Any] get_cache_key = self.cache_key_builder @@ -212,7 +219,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): class DeferredCacheDescriptor(_CacheDescriptorBase): - """ A method decorator that applies a memoizing cache around the function. + """A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that fail are removed from the cache. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index c541bf4579..644e9e778a 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -84,8 +84,7 @@ class StreamChangeCache: return False def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: - """Returns True if the entity may have been updated since stream_pos - """ + """Returns True if the entity may have been updated since stream_pos""" assert isinstance(stream_pos, int) if stream_pos < self._earliest_known_stream_pos: @@ -133,8 +132,7 @@ class StreamChangeCache: return result def has_any_entity_changed(self, stream_pos: int) -> bool: - """Returns if any entity has changed - """ + """Returns if any entity has changed""" assert type(stream_pos) is int if not self._cache: diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index a6ee9edaec..3c47285d05 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -108,7 +108,10 @@ class Signal: return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: logger.warning( - "%s signal observer %s failed: %r", self.name, observer, e, + "%s signal observer %s failed: %r", + self.name, + observer, + e, ) deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 733f5e26e6..68dc632491 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -83,15 +83,13 @@ class BackgroundFileConsumer: self._producer.resumeProducing() def unregisterProducer(self): - """Part of IProducer interface - """ + """Part of IProducer interface""" self._producer = None if not self._finished_deferred.called: self._bytes_queue.put_nowait(None) def write(self, bytes): - """Part of IProducer interface - """ + """Part of IProducer interface""" if self._write_exception: raise self._write_exception @@ -107,8 +105,7 @@ class BackgroundFileConsumer: self._producer.pauseProducing() def _writer(self): - """This is run in a background thread to write to the file. - """ + """This is run in a background thread to write to the file.""" try: while self._producer or not self._bytes_queue.empty(): # If we've paused the producer check if we should resume the @@ -135,13 +132,11 @@ class BackgroundFileConsumer: self._file_obj.close() def wait(self): - """Returns a deferred that resolves when finished writing to file - """ + """Returns a deferred that resolves when finished writing to file""" return make_deferred_yieldable(self._finished_deferred) def _resume_paused_producer(self): - """Gets called if we should resume producing after being paused - """ + """Gets called if we should resume producing after being paused""" if self._paused_producer and self._producer: self._paused_producer = False self._producer.resumeProducing() diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 8d2411513f..98707c119d 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -62,7 +62,8 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: def sorted_topologically( - nodes: Iterable[T], graph: Mapping[T, Collection[T]], + nodes: Iterable[T], + graph: Mapping[T, Collection[T]], ) -> Generator[T, None, None]: """Given a set of nodes and a graph, yield the nodes in toplogical order. diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index 50516926f3..e3a8ed5b2f 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -15,7 +15,7 @@ class JsonEncodedObject: - """ A common base class for defining protocol units that are represented + """A common base class for defining protocol units that are represented as JSON. Attributes: @@ -39,7 +39,7 @@ class JsonEncodedObject: """ def __init__(self, **kwargs): - """ Takes the dict of `kwargs` and loads all keys that are *valid* + """Takes the dict of `kwargs` and loads all keys that are *valid* (i.e., are included in the `valid_keys` list) into the dictionary` instance variable. @@ -61,7 +61,7 @@ class JsonEncodedObject: self.unrecognized_keys[k] = v def get_dict(self): - """ Converts this protocol unit into a :py:class:`dict`, ready to be + """Converts this protocol unit into a :py:class:`dict`, ready to be encoded as JSON. The keys it encodes are: `valid_keys` - `internal_keys` diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index f4de6b9f54..1023c856d1 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -161,8 +161,7 @@ class Measure: return self._logging_context.get_resource_usage() def _update_in_flight(self, metrics): - """Gets called when processing in flight metrics - """ + """Gets called when processing in flight metrics""" duration = self.clock.time() - self.start metrics.real_time_max = max(metrics.real_time_max, duration) diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 1ee61851e4..d184e2a90c 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -25,7 +25,7 @@ from synapse.config._util import json_error_to_config_error def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: - """ Loads a synapse module with its config + """Loads a synapse module with its config Args: provider: a dict with keys 'module' (the module name) and 'config' @@ -49,7 +49,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: module = importlib.import_module(module) provider_class = getattr(module, clz) - module_config = provider.get("config") + # Load the module config. If None, pass an empty dictionary instead + module_config = provider.get("config") or {} try: provider_config = provider_class.parse_config(module_config) except jsonschema.ValidationError as e: diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 72574d3af2..d9f9ae99d6 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -204,16 +204,13 @@ def _check_yield_points(f: Callable, changes: List[str]): # We don't raise here as its perfectly valid for contexts to # change in a function, as long as it sets the correct context # on resolving (which is checked separately). - err = ( - "%s changed context from %s to %s, happened between lines %d and %d in %s" - % ( - frame.f_code.co_name, - expected_context, - current_context(), - last_yield_line_no, - frame.f_lineno, - frame.f_code.co_filename, - ) + err = "%s changed context from %s to %s, happened between lines %d and %d in %s" % ( + frame.f_code.co_name, + expected_context, + current_context(), + last_yield_line_no, + frame.f_lineno, + frame.f_code.co_filename, ) changes.append(err) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index f8038bf861..9ce7873ab5 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken -client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$") # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically @@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") rand = random.SystemRandom() -def random_string(length): +def random_string(length: int) -> str: return "".join(rand.choice(string.ascii_letters) for _ in range(length)) -def random_string_with_symbols(length): +def random_string_with_symbols(length: int) -> str: return "".join(rand.choice(_string_with_symbols) for _ in range(length)) -def is_ascii(s): - if isinstance(s, bytes): - try: - s.decode("ascii").encode("ascii") - except UnicodeDecodeError: - return False - except UnicodeEncodeError: - return False - return True +def is_ascii(s: bytes) -> bool: + try: + s.decode("ascii").encode("ascii") + except UnicodeDecodeError: + return False + except UnicodeEncodeError: + return False + return True -def assert_valid_client_secret(client_secret): - """Validate that a given string matches the client_secret regex defined by the spec""" - if client_secret_regex.match(client_secret) is None: +def assert_valid_client_secret(client_secret: str) -> None: + """Validate that a given string matches the client_secret defined by the spec""" + if ( + len(client_secret) <= 0 + or len(client_secret) > 255 + or CLIENT_SECRET_REGEX.match(client_secret) is None + ): raise SynapseError( 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM ) diff --git a/synapse/util/templates.py b/synapse/util/templates.py new file mode 100644 index 0000000000..392dae4a40 --- /dev/null +++ b/synapse/util/templates.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for dealing with jinja2 templates""" + +import time +import urllib.parse +from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union + +import jinja2 + +if TYPE_CHECKING: + from synapse.config.homeserver import HomeServerConfig + + +def build_jinja_env( + template_search_directories: Iterable[str], + config: "HomeServerConfig", + autoescape: Union[bool, Callable[[str], bool], None] = None, +) -> jinja2.Environment: + """Set up a Jinja2 environment to load templates from the given search path + + The returned environment defines the following filters: + - format_ts: formats timestamps as strings in the server's local timezone + (XXX: why is that useful??) + - mxc_to_http: converts mxc: uris to http URIs. Args are: + (uri, width, height, resize_method="crop") + + and the following global variables: + - server_name: matrix server name + + Args: + template_search_directories: directories to search for templates + + config: homeserver config, for things like `server_name` and `public_baseurl` + + autoescape: whether template variables should be autoescaped. bool, or + a function mapping from template name to bool. Defaults to escaping templates + whose names end in .html, .xml or .htm. + + Returns: + jinja environment + """ + + if autoescape is None: + autoescape = jinja2.select_autoescape() + + loader = jinja2.FileSystemLoader(template_search_directories) + env = jinja2.Environment(loader=loader, autoescape=autoescape) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl), + } + ) + + # common variables for all templates + env.globals.update({"server_name": config.server_name}) + + return env + + +def _create_mxc_to_http_filter( + public_baseurl: Optional[str], +) -> Callable[[str, int, int, str], str]: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter( + value: str, width: int, height: int, resize_method: str = "crop" + ) -> str: + if not public_baseurl: + raise RuntimeError( + "public_baseurl must be set in the homeserver config to convert MXC URLs to HTTP URLs." + ) + + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) |