summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py15
-rw-r--r--synapse/util/caches/__init__.py6
-rw-r--r--synapse/util/caches/cached_call.py129
-rw-r--r--synapse/util/caches/descriptors.py17
-rw-r--r--synapse/util/caches/stream_change_cache.py6
-rw-r--r--synapse/util/distributor.py5
-rw-r--r--synapse/util/file_consumer.py15
-rw-r--r--synapse/util/iterutils.py3
-rw-r--r--synapse/util/jsonobject.py6
-rw-r--r--synapse/util/metrics.py3
-rw-r--r--synapse/util/module_loader.py2
-rw-r--r--synapse/util/patch_inline_callbacks.py17
-rw-r--r--synapse/util/stringutils.py33
13 files changed, 193 insertions, 64 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 9a873c8e8e..719e35b78d 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -252,8 +252,7 @@ class Linearizer:
         self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]
 
     def is_queued(self, key: Hashable) -> bool:
-        """Checks whether there is a process queued up waiting
-        """
+        """Checks whether there is a process queued up waiting"""
         entry = self.key_to_defer.get(key)
         if not entry:
             # No entry so nothing is waiting.
@@ -452,7 +451,9 @@ R = TypeVar("R")
 
 
 def timeout_deferred(
-    deferred: defer.Deferred, timeout: float, reactor: IReactorTime,
+    deferred: defer.Deferred,
+    timeout: float,
+    reactor: IReactorTime,
 ) -> defer.Deferred:
     """The in built twisted `Deferred.addTimeout` fails to time out deferreds
     that have a canceller that throws exceptions. This method creates a new
@@ -497,7 +498,7 @@ def timeout_deferred(
     delayed_call = reactor.callLater(timeout, time_it_out)
 
     def convert_cancelled(value: failure.Failure):
-        # if the orgininal deferred was cancelled, and our timeout has fired, then
+        # if the original deferred was cancelled, and our timeout has fired, then
         # the reason it was cancelled was due to our timeout. Turn the CancelledError
         # into a TimeoutError.
         if timed_out[0] and value.check(CancelledError):
@@ -529,8 +530,7 @@ def timeout_deferred(
 
 @attr.s(slots=True, frozen=True)
 class DoneAwaitable:
-    """Simple awaitable that returns the provided value.
-    """
+    """Simple awaitable that returns the provided value."""
 
     value = attr.ib()
 
@@ -545,8 +545,7 @@ class DoneAwaitable:
 
 
 def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
-    """Convert a value to an awaitable if not already an awaitable.
-    """
+    """Convert a value to an awaitable if not already an awaitable."""
     if inspect.isawaitable(value):
         assert isinstance(value, Awaitable)
         return value
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 89f0b38535..e676c2cac4 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -149,8 +149,7 @@ KNOWN_KEYS = {
 
 
 def intern_string(string):
-    """Takes a (potentially) unicode string and interns it if it's ascii
-    """
+    """Takes a (potentially) unicode string and interns it if it's ascii"""
     if string is None:
         return None
 
@@ -161,8 +160,7 @@ def intern_string(string):
 
 
 def intern_dict(dictionary):
-    """Takes a dictionary and interns well known keys and their values
-    """
+    """Takes a dictionary and interns well known keys and their values"""
     return {
         KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
         for key, value in dictionary.items()
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
new file mode 100644
index 0000000000..3ee0f2317a
--- /dev/null
+++ b/synapse/util/caches/cached_call.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
+
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+
+TV = TypeVar("TV")
+
+
+class CachedCall(Generic[TV]):
+    """A wrapper for asynchronous calls whose results should be shared
+
+    This is useful for wrapping asynchronous functions, where there might be multiple
+    callers, but we only want to call the underlying function once (and have the result
+    returned to all callers).
+
+    Similar results can be achieved via a lock of some form, but that typically requires
+    more boilerplate (and ends up being less efficient).
+
+    Correctly handles Synapse logcontexts (logs and resource usage for the underlying
+    function are logged against the logcontext which is active when get() is first
+    called).
+
+    Example usage:
+
+        _cached_val = CachedCall(_load_prop)
+
+        async def handle_request() -> X:
+            # We can call this multiple times, but it will result in a single call to
+            # _load_prop().
+            return await _cached_val.get()
+
+        async def _load_prop() -> X:
+            await difficult_operation()
+
+
+    The implementation is deliberately single-shot (ie, once the call is initiated,
+    there is no way to ask for it to be run). This keeps the implementation and
+    semantics simple. If you want to make a new call, simply replace the whole
+    CachedCall object.
+    """
+
+    __slots__ = ["_callable", "_deferred", "_result"]
+
+    def __init__(self, f: Callable[[], Awaitable[TV]]):
+        """
+        Args:
+            f: The underlying function. Only one call to this function will be alive
+                at once (per instance of CachedCall)
+        """
+        self._callable = f  # type: Optional[Callable[[], Awaitable[TV]]]
+        self._deferred = None  # type: Optional[Deferred]
+        self._result = None  # type: Union[None, Failure, TV]
+
+    async def get(self) -> TV:
+        """Kick off the call if necessary, and return the result"""
+
+        # Fire off the callable now if this is our first time
+        if not self._deferred:
+            self._deferred = run_in_background(self._callable)
+
+            # we will never need the callable again, so make sure it can be GCed
+            self._callable = None
+
+            # once the deferred completes, store the result. We cannot simply leave the
+            # result in the deferred, since if it's a Failure, GCing the deferred
+            # would then log a critical error about unhandled Failures.
+            def got_result(r):
+                self._result = r
+
+            self._deferred.addBoth(got_result)
+
+        # TODO: consider cancellation semantics. Currently, if the call to get()
+        #    is cancelled, the underlying call will continue (and any future calls
+        #    will get the result/exception), which I think is *probably* ok, modulo
+        #    the fact the underlying call may be logged to a cancelled logcontext,
+        #    and any eventual exception may not be reported.
+
+        # we can now await the deferred, and once it completes, return the result.
+        await make_deferred_yieldable(self._deferred)
+
+        # I *think* this is the easiest way to correctly raise a Failure without having
+        # to gut-wrench into the implementation of Deferred.
+        d = Deferred()
+        d.callback(self._result)
+        return await d
+
+
+class RetryOnExceptionCachedCall(Generic[TV]):
+    """A wrapper around CachedCall which will retry the call if an exception is thrown
+
+    This is used in much the same way as CachedCall, but adds some extra functionality
+    so that if the underlying function throws an exception, then the next call to get()
+    will initiate another call to the underlying function. (Any calls to get() which
+    are already pending will raise the exception.)
+    """
+
+    slots = ["_cachedcall"]
+
+    def __init__(self, f: Callable[[], Awaitable[TV]]):
+        async def _wrapper() -> TV:
+            try:
+                return await f()
+            except Exception:
+                # the call raised an exception: replace the underlying CachedCall to
+                # trigger another call next time get() is called
+                self._cachedcall = CachedCall(_wrapper)
+                raise
+
+        self._cachedcall = CachedCall(_wrapper)
+
+    async def get(self) -> TV:
+        return await self._cachedcall.get()
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index a924140cdf..4e84379914 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -122,7 +122,8 @@ class _LruCachedFunction(Generic[F]):
 
 
 def lru_cache(
-    max_entries: int = 1000, cache_context: bool = False,
+    max_entries: int = 1000,
+    cache_context: bool = False,
 ) -> Callable[[F], _LruCachedFunction[F]]:
     """A method decorator that applies a memoizing cache around the function.
 
@@ -156,7 +157,9 @@ def lru_cache(
 
     def func(orig: F) -> _LruCachedFunction[F]:
         desc = LruCacheDescriptor(
-            orig, max_entries=max_entries, cache_context=cache_context,
+            orig,
+            max_entries=max_entries,
+            cache_context=cache_context,
         )
         return cast(_LruCachedFunction[F], desc)
 
@@ -170,14 +173,18 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         sentinel = object()
 
     def __init__(
-        self, orig, max_entries: int = 1000, cache_context: bool = False,
+        self,
+        orig,
+        max_entries: int = 1000,
+        cache_context: bool = False,
     ):
         super().__init__(orig, num_args=None, cache_context=cache_context)
         self.max_entries = max_entries
 
     def __get__(self, obj, owner):
         cache = LruCache(
-            cache_name=self.orig.__name__, max_size=self.max_entries,
+            cache_name=self.orig.__name__,
+            max_size=self.max_entries,
         )  # type: LruCache[CacheKey, Any]
 
         get_cache_key = self.cache_key_builder
@@ -212,7 +219,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
 
 class DeferredCacheDescriptor(_CacheDescriptorBase):
-    """ A method decorator that applies a memoizing cache around the function.
+    """A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
     fail are removed from the cache.
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index c541bf4579..644e9e778a 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -84,8 +84,7 @@ class StreamChangeCache:
         return False
 
     def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
-        """Returns True if the entity may have been updated since stream_pos
-        """
+        """Returns True if the entity may have been updated since stream_pos"""
         assert isinstance(stream_pos, int)
 
         if stream_pos < self._earliest_known_stream_pos:
@@ -133,8 +132,7 @@ class StreamChangeCache:
         return result
 
     def has_any_entity_changed(self, stream_pos: int) -> bool:
-        """Returns if any entity has changed
-        """
+        """Returns if any entity has changed"""
         assert type(stream_pos) is int
 
         if not self._cache:
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index a6ee9edaec..3c47285d05 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -108,7 +108,10 @@ class Signal:
                 return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
                 logger.warning(
-                    "%s signal observer %s failed: %r", self.name, observer, e,
+                    "%s signal observer %s failed: %r",
+                    self.name,
+                    observer,
+                    e,
                 )
 
         deferreds = [run_in_background(do, o) for o in self.observers]
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 733f5e26e6..68dc632491 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -83,15 +83,13 @@ class BackgroundFileConsumer:
             self._producer.resumeProducing()
 
     def unregisterProducer(self):
-        """Part of IProducer interface
-        """
+        """Part of IProducer interface"""
         self._producer = None
         if not self._finished_deferred.called:
             self._bytes_queue.put_nowait(None)
 
     def write(self, bytes):
-        """Part of IProducer interface
-        """
+        """Part of IProducer interface"""
         if self._write_exception:
             raise self._write_exception
 
@@ -107,8 +105,7 @@ class BackgroundFileConsumer:
             self._producer.pauseProducing()
 
     def _writer(self):
-        """This is run in a background thread to write to the file.
-        """
+        """This is run in a background thread to write to the file."""
         try:
             while self._producer or not self._bytes_queue.empty():
                 # If we've paused the producer check if we should resume the
@@ -135,13 +132,11 @@ class BackgroundFileConsumer:
             self._file_obj.close()
 
     def wait(self):
-        """Returns a deferred that resolves when finished writing to file
-        """
+        """Returns a deferred that resolves when finished writing to file"""
         return make_deferred_yieldable(self._finished_deferred)
 
     def _resume_paused_producer(self):
-        """Gets called if we should resume producing after being paused
-        """
+        """Gets called if we should resume producing after being paused"""
         if self._paused_producer and self._producer:
             self._paused_producer = False
             self._producer.resumeProducing()
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 8d2411513f..98707c119d 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -62,7 +62,8 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
 
 
 def sorted_topologically(
-    nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+    nodes: Iterable[T],
+    graph: Mapping[T, Collection[T]],
 ) -> Generator[T, None, None]:
     """Given a set of nodes and a graph, yield the nodes in toplogical order.
 
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 50516926f3..e3a8ed5b2f 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -15,7 +15,7 @@
 
 
 class JsonEncodedObject:
-    """ A common base class for defining protocol units that are represented
+    """A common base class for defining protocol units that are represented
     as JSON.
 
     Attributes:
@@ -39,7 +39,7 @@ class JsonEncodedObject:
     """
 
     def __init__(self, **kwargs):
-        """ Takes the dict of `kwargs` and loads all keys that are *valid*
+        """Takes the dict of `kwargs` and loads all keys that are *valid*
         (i.e., are included in the `valid_keys` list) into the dictionary`
         instance variable.
 
@@ -61,7 +61,7 @@ class JsonEncodedObject:
                 self.unrecognized_keys[k] = v
 
     def get_dict(self):
-        """ Converts this protocol unit into a :py:class:`dict`, ready to be
+        """Converts this protocol unit into a :py:class:`dict`, ready to be
         encoded as JSON.
 
         The keys it encodes are: `valid_keys` - `internal_keys`
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index f4de6b9f54..1023c856d1 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -161,8 +161,7 @@ class Measure:
         return self._logging_context.get_resource_usage()
 
     def _update_in_flight(self, metrics):
-        """Gets called when processing in flight metrics
-        """
+        """Gets called when processing in flight metrics"""
         duration = self.clock.time() - self.start
 
         metrics.real_time_max = max(metrics.real_time_max, duration)
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 09b094ded7..d184e2a90c 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -25,7 +25,7 @@ from synapse.config._util import json_error_to_config_error
 
 
 def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
-    """ Loads a synapse module with its config
+    """Loads a synapse module with its config
 
     Args:
         provider: a dict with keys 'module' (the module name) and 'config'
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 72574d3af2..d9f9ae99d6 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -204,16 +204,13 @@ def _check_yield_points(f: Callable, changes: List[str]):
                 # We don't raise here as its perfectly valid for contexts to
                 # change in a function, as long as it sets the correct context
                 # on resolving (which is checked separately).
-                err = (
-                    "%s changed context from %s to %s, happened between lines %d and %d in %s"
-                    % (
-                        frame.f_code.co_name,
-                        expected_context,
-                        current_context(),
-                        last_yield_line_no,
-                        frame.f_lineno,
-                        frame.f_code.co_filename,
-                    )
+                err = "%s changed context from %s to %s, happened between lines %d and %d in %s" % (
+                    frame.f_code.co_name,
+                    expected_context,
+                    current_context(),
+                    last_yield_line_no,
+                    frame.f_lineno,
+                    frame.f_code.co_filename,
                 )
                 changes.append(err)
 
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index f8038bf861..9ce7873ab5 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError
 _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
 
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
+CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
 
 # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
 # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
@@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
 rand = random.SystemRandom()
 
 
-def random_string(length):
+def random_string(length: int) -> str:
     return "".join(rand.choice(string.ascii_letters) for _ in range(length))
 
 
-def random_string_with_symbols(length):
+def random_string_with_symbols(length: int) -> str:
     return "".join(rand.choice(_string_with_symbols) for _ in range(length))
 
 
-def is_ascii(s):
-    if isinstance(s, bytes):
-        try:
-            s.decode("ascii").encode("ascii")
-        except UnicodeDecodeError:
-            return False
-        except UnicodeEncodeError:
-            return False
-        return True
+def is_ascii(s: bytes) -> bool:
+    try:
+        s.decode("ascii").encode("ascii")
+    except UnicodeDecodeError:
+        return False
+    except UnicodeEncodeError:
+        return False
+    return True
 
 
-def assert_valid_client_secret(client_secret):
-    """Validate that a given string matches the client_secret regex defined by the spec"""
-    if client_secret_regex.match(client_secret) is None:
+def assert_valid_client_secret(client_secret: str) -> None:
+    """Validate that a given string matches the client_secret defined by the spec"""
+    if (
+        len(client_secret) <= 0
+        or len(client_secret) > 255
+        or CLIENT_SECRET_REGEX.match(client_secret) is None
+    ):
         raise SynapseError(
             400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
         )