summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2022-11-18 17:05:41 -0600
committerEric Eastwood <erice@element.io>2022-11-18 17:05:41 -0600
commit04de9ea73809670f4ce2cf0dc085065a324453e8 (patch)
tree95cb00be91df21d3002651b0835eca8ea88645e5 /synapse/util
parentMerge branch 'madlittlemods/11850-migrate-to-opentelemetry' into madlittlemod... (diff)
parentMerge branch 'develop' into madlittlemods/11850-migrate-to-opentelemetry (diff)
downloadsynapse-madlittlemods/13356-messages-investigation-scratch-v1.tar.xz
Merge branch 'madlittlemods/11850-migrate-to-opentelemetry' into madlittlemods/13356-messages-investigation-scratch-v1 github/madlittlemods/13356-messages-investigation-scratch-v1 madlittlemods/13356-messages-investigation-scratch-v1
Conflicts:
	synapse/handlers/federation.py
	synapse/handlers/relations.py
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py14
-rw-r--r--synapse/util/async_helpers.py3
-rw-r--r--synapse/util/caches/__init__.py2
-rw-r--r--synapse/util/caches/deferred_cache.py6
-rw-r--r--synapse/util/caches/descriptors.py110
-rw-r--r--synapse/util/caches/dictionary_cache.py9
-rw-r--r--synapse/util/caches/expiringcache.py2
-rw-r--r--synapse/util/caches/lrucache.py8
-rw-r--r--synapse/util/caches/stream_change_cache.py11
-rw-r--r--synapse/util/check_dependencies.py17
-rw-r--r--synapse/util/macaroons.py87
-rw-r--r--synapse/util/ratelimitutils.py2
-rw-r--r--synapse/util/retryutils.py2
-rw-r--r--synapse/util/stringutils.py4
-rw-r--r--synapse/util/threepids.py2
-rw-r--r--synapse/util/wheel_timer.py4
16 files changed, 61 insertions, 222 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index a90f08dd4c..7be9d5f113 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -15,7 +15,7 @@
 import json
 import logging
 import typing
-from typing import Any, Callable, Dict, Generator, Optional
+from typing import Any, Callable, Dict, Generator, Optional, Sequence
 
 import attr
 from frozendict import frozendict
@@ -193,3 +193,15 @@ def log_failure(
 # Version string with git info. Computed here once so that we don't invoke git multiple
 # times.
 SYNAPSE_VERSION = get_distribution_version_string("matrix-synapse", __file__)
+
+
+class ExceptionBundle(Exception):
+    # A poor stand-in for something like Python 3.11's ExceptionGroup.
+    # (A backport called `exceptiongroup` exists but seems overkill: we just want a
+    # container type here.)
+    def __init__(self, message: str, exceptions: Sequence[Exception]):
+        parts = [message]
+        for e in exceptions:
+            parts.append(str(e))
+        super().__init__("\n  - ".join(parts))
+        self.exceptions = exceptions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 7f1d41eb3c..d24c4f68c4 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -217,7 +217,8 @@ async def concurrently_execute(
         limit: Maximum number of conccurent executions.
 
     Returns:
-        Deferred: Resolved when all function invocations have finished.
+        None, when all function invocations have finished. The return values
+        from those functions are discarded.
     """
     it = iter(args)
 
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index f7c3a6794e..9387632d0d 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -197,7 +197,7 @@ def register_cache(
         resize_callback: A function which can be called to resize the cache.
 
     Returns:
-        CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+        an object which provides inc_{hits,misses,evictions} methods
     """
     if resizable:
         if not resize_callback:
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 6425f851ea..bf7bd351e0 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -153,7 +153,7 @@ class DeferredCache(Generic[KT, VT]):
         Args:
             key:
             callback: Gets called when the entry in the cache is invalidated
-            update_metrics (bool): whether to update the cache hit rate metrics
+            update_metrics: whether to update the cache hit rate metrics
 
         Returns:
             A Deferred which completes with the result. Note that this may later fail
@@ -395,8 +395,8 @@ class DeferredCache(Generic[KT, VT]):
             # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
             # case of a TreeCache, a dict of keys to cache entries. Either way calling
             # iterate_tree_cache_entry on it will do the right thing.
-            for entry in iterate_tree_cache_entry(entry):
-                for cb in entry.get_invalidation_callbacks(key):
+            for iter_entry in iterate_tree_cache_entry(entry):
+                for cb in iter_entry.get_invalidation_callbacks(key):
                     cb()
 
     def invalidate_all(self) -> None:
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 3909f1caea..75428d19ba 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import enum
 import functools
 import inspect
 import logging
@@ -146,109 +145,6 @@ class _CacheDescriptorBase:
         )
 
 
-class _LruCachedFunction(Generic[F]):
-    cache: LruCache[CacheKey, Any]
-    __call__: F
-
-
-def lru_cache(
-    *, max_entries: int = 1000, cache_context: bool = False
-) -> Callable[[F], _LruCachedFunction[F]]:
-    """A method decorator that applies a memoizing cache around the function.
-
-    This is more-or-less a drop-in equivalent to functools.lru_cache, although note
-    that the signature is slightly different.
-
-    The main differences with functools.lru_cache are:
-        (a) the size of the cache can be controlled via the cache_factor mechanism
-        (b) the wrapped function can request a "cache_context" which provides a
-            callback mechanism to indicate that the result is no longer valid
-        (c) prometheus metrics are exposed automatically.
-
-    The function should take zero or more arguments, which are used as the key for the
-    cache. Single-argument functions use that argument as the cache key; otherwise the
-    arguments are built into a tuple.
-
-    Cached functions can be "chained" (i.e. a cached function can call other cached
-    functions and get appropriately invalidated when they called caches are
-    invalidated) by adding a special "cache_context" argument to the function
-    and passing that as a kwarg to all caches called. For example:
-
-        @lru_cache(cache_context=True)
-        def foo(self, key, cache_context):
-            r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
-            r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
-            return r1 + r2
-
-    The wrapped function also has a 'cache' property which offers direct access to the
-    underlying LruCache.
-    """
-
-    def func(orig: F) -> _LruCachedFunction[F]:
-        desc = LruCacheDescriptor(
-            orig,
-            max_entries=max_entries,
-            cache_context=cache_context,
-        )
-        return cast(_LruCachedFunction[F], desc)
-
-    return func
-
-
-class LruCacheDescriptor(_CacheDescriptorBase):
-    """Helper for @lru_cache"""
-
-    class _Sentinel(enum.Enum):
-        sentinel = object()
-
-    def __init__(
-        self,
-        orig: Callable[..., Any],
-        max_entries: int = 1000,
-        cache_context: bool = False,
-    ):
-        super().__init__(
-            orig, num_args=None, uncached_args=None, cache_context=cache_context
-        )
-        self.max_entries = max_entries
-
-    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
-        cache: LruCache[CacheKey, Any] = LruCache(
-            cache_name=self.name,
-            max_size=self.max_entries,
-        )
-
-        get_cache_key = self.cache_key_builder
-        sentinel = LruCacheDescriptor._Sentinel.sentinel
-
-        @functools.wraps(self.orig)
-        def _wrapped(*args: Any, **kwargs: Any) -> Any:
-            invalidate_callback = kwargs.pop("on_invalidate", None)
-            callbacks = (invalidate_callback,) if invalidate_callback else ()
-
-            cache_key = get_cache_key(args, kwargs)
-
-            ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
-            if ret != sentinel:
-                return ret
-
-            # Add our own `cache_context` to argument list if the wrapped function
-            # has asked for one
-            if self.add_cache_context:
-                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
-
-            ret2 = self.orig(obj, *args, **kwargs)
-            cache.set(cache_key, ret2, callbacks=callbacks)
-
-            return ret2
-
-        wrapped = cast(CachedFunction, _wrapped)
-        wrapped.cache = cache
-        obj.__dict__[self.name] = wrapped
-
-        return wrapped
-
-
 class DeferredCacheDescriptor(_CacheDescriptorBase):
     """A method decorator that applies a memoizing cache around the function.
 
@@ -431,6 +327,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
+        if num_args != self.num_args:
+            raise TypeError(
+                "Number of args (%s) does not match underlying cache_method_name=%s (%s)."
+                % (self.num_args, self.cached_method_name, num_args)
+            )
+
         @functools.wraps(self.orig)
         def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
             # If we're passed a cache_context then we'll want to call its
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index fa91479c97..5eaf70c7ab 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -169,10 +169,11 @@ class DictionaryCache(Generic[KT, DKT, DV]):
                 if it is in the cache.
 
         Returns:
-            DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry`
-            will contain include the keys that are in the cache. If None then
-            will either return the full dict if in the cache, or the empty
-            dict (with `full` set to False) if it isn't.
+            If `dict_keys` is not None then `DictionaryEntry` will contain include
+            the keys that are in the cache.
+
+            If None then will either return the full dict if in the cache, or the
+            empty dict (with `full` set to False) if it isn't.
         """
         if dict_keys is None:
             # The caller wants the full set of dictionary keys for this cache key
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index c6a5d0dfc0..01ad02af67 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -207,7 +207,7 @@ class ExpiringCache(Generic[KT, VT]):
         items from the cache.
 
         Returns:
-            bool: Whether the cache changed size or not.
+            Whether the cache changed size or not.
         """
         new_size = int(self._original_max_size * factor)
         if new_size != self._max_size:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index aa93109d13..dcf0eac3bf 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -389,11 +389,11 @@ class LruCache(Generic[KT, VT]):
             cache_name: The name of this cache, for the prometheus metrics. If unset,
                 no metrics will be reported on this cache.
 
-            cache_type (type):
+            cache_type:
                 type of underlying cache to be used. Typically one of dict
                 or TreeCache.
 
-            size_callback (func(V) -> int | None):
+            size_callback:
 
             metrics_collection_callback:
                 metrics collection callback. This is called early in the metrics
@@ -403,7 +403,7 @@ class LruCache(Generic[KT, VT]):
 
                 Ignored if cache_name is None.
 
-            apply_cache_factor_from_config (bool): If true, `max_size` will be
+            apply_cache_factor_from_config: If true, `max_size` will be
                 multiplied by a cache factor derived from the homeserver config
 
             clock:
@@ -796,7 +796,7 @@ class LruCache(Generic[KT, VT]):
         items from the cache.
 
         Returns:
-            bool: Whether the cache changed size or not.
+            Whether the cache changed size or not.
         """
         if not self.apply_cache_factor_from_config:
             return False
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 330709b8b7..666f4b6895 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -72,7 +72,7 @@ class StreamChangeCache:
         items from the cache.
 
         Returns:
-            bool: Whether the cache changed size or not.
+            Whether the cache changed size or not.
         """
         new_size = math.floor(self._original_max_size * factor)
         if new_size != self._max_size:
@@ -188,14 +188,8 @@ class StreamChangeCache:
         self._entity_to_key[entity] = stream_pos
         self._evict()
 
-        # if the cache is too big, remove entries
-        while len(self._cache) > self._max_size:
-            k, r = self._cache.popitem(0)
-            self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
-            for entity in r:
-                del self._entity_to_key[entity]
-
     def _evict(self) -> None:
+        # if the cache is too big, remove entries
         while len(self._cache) > self._max_size:
             k, r = self._cache.popitem(0)
             self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
@@ -203,7 +197,6 @@ class StreamChangeCache:
                 self._entity_to_key.pop(entity, None)
 
     def get_max_pos_of_last_change(self, entity: EntityType) -> int:
-
         """Returns an upper bound of the stream id of the last change to an
         entity.
         """
diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index 66f1da7502..3b1e205700 100644
--- a/synapse/util/check_dependencies.py
+++ b/synapse/util/check_dependencies.py
@@ -66,6 +66,21 @@ def _is_dev_dependency(req: Requirement) -> bool:
     )
 
 
+def _should_ignore_runtime_requirement(req: Requirement) -> bool:
+    # This is a build-time dependency. Irritatingly, `poetry build` ignores the
+    # requirements listed in the [build-system] section of pyproject.toml, so in order
+    # to support `poetry install --no-dev` we have to mark it as a runtime dependency.
+    # See discussion on https://github.com/python-poetry/poetry/issues/6154 (it sounds
+    # like the poetry authors don't consider this a bug?)
+    #
+    # In any case, workaround this by ignoring setuptools_rust here. (It might be
+    # slightly cleaner to put `setuptools_rust` in a `build` extra or similar, but for
+    # now let's do something quick and dirty.
+    if req.name == "setuptools_rust":
+        return True
+    return False
+
+
 class Dependency(NamedTuple):
     requirement: Requirement
     must_be_installed: bool
@@ -77,7 +92,7 @@ def _generic_dependencies() -> Iterable[Dependency]:
     assert requirements is not None
     for raw_requirement in requirements:
         req = Requirement(raw_requirement)
-        if _is_dev_dependency(req):
+        if _is_dev_dependency(req) or _should_ignore_runtime_requirement(req):
             continue
 
         # https://packaging.pypa.io/en/latest/markers.html#usage notes that
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index df77edcce2..5df03d3ddc 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -24,7 +24,7 @@ from typing_extensions import Literal
 
 from synapse.util import Clock, stringutils
 
-MacaroonType = Literal["access", "delete_pusher", "session", "login"]
+MacaroonType = Literal["access", "delete_pusher", "session"]
 
 
 def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
@@ -111,19 +111,6 @@ class OidcSessionData:
     """The session ID of the ongoing UI Auth ("" if this is a login)"""
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class LoginTokenAttributes:
-    """Data we store in a short-term login token"""
-
-    user_id: str
-
-    auth_provider_id: str
-    """The SSO Identity Provider that the user authenticated with, to get this token."""
-
-    auth_provider_session_id: Optional[str]
-    """The session ID advertised by the SSO Identity Provider."""
-
-
 class MacaroonGenerator:
     def __init__(self, clock: Clock, location: str, secret_key: bytes):
         self._clock = clock
@@ -165,35 +152,6 @@ class MacaroonGenerator:
         macaroon.add_first_party_caveat(f"pushkey = {pushkey}")
         return macaroon.serialize()
 
-    def generate_short_term_login_token(
-        self,
-        user_id: str,
-        auth_provider_id: str,
-        auth_provider_session_id: Optional[str] = None,
-        duration_in_ms: int = (2 * 60 * 1000),
-    ) -> str:
-        """Generate a short-term login token used during SSO logins
-
-        Args:
-            user_id: The user for which the token is valid.
-            auth_provider_id: The SSO IdP the user used.
-            auth_provider_session_id: The session ID got during login from the SSO IdP.
-
-        Returns:
-            A signed token valid for using as a ``m.login.token`` token.
-        """
-        now = self._clock.time_msec()
-        expiry = now + duration_in_ms
-        macaroon = self._generate_base_macaroon("login")
-        macaroon.add_first_party_caveat(f"user_id = {user_id}")
-        macaroon.add_first_party_caveat(f"time < {expiry}")
-        macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}")
-        if auth_provider_session_id is not None:
-            macaroon.add_first_party_caveat(
-                f"auth_provider_session_id = {auth_provider_session_id}"
-            )
-        return macaroon.serialize()
-
     def generate_oidc_session_token(
         self,
         state: str,
@@ -233,49 +191,6 @@ class MacaroonGenerator:
 
         return macaroon.serialize()
 
-    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
-        """Verify a short-term-login macaroon
-
-        Checks that the given token is a valid, unexpired short-term-login token
-        minted by this server.
-
-        Args:
-            token: The login token to verify.
-
-        Returns:
-            A set of attributes carried by this token, including the
-            ``user_id`` and informations about the SSO IDP used during that
-            login.
-
-        Raises:
-            MacaroonVerificationFailedException if the verification failed
-        """
-        macaroon = pymacaroons.Macaroon.deserialize(token)
-
-        v = self._base_verifier("login")
-        v.satisfy_general(lambda c: c.startswith("user_id = "))
-        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
-        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
-        satisfy_expiry(v, self._clock.time_msec)
-        v.verify(macaroon, self._secret_key)
-
-        user_id = get_value_from_macaroon(macaroon, "user_id")
-        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
-
-        auth_provider_session_id: Optional[str] = None
-        try:
-            auth_provider_session_id = get_value_from_macaroon(
-                macaroon, "auth_provider_session_id"
-            )
-        except MacaroonVerificationFailedException:
-            pass
-
-        return LoginTokenAttributes(
-            user_id=user_id,
-            auth_provider_id=auth_provider_id,
-            auth_provider_session_id=auth_provider_session_id,
-        )
-
     def verify_guest_token(self, token: str) -> str:
         """Verify a guest access token macaroon
 
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 0154f92107..98f05fc792 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -183,7 +183,7 @@ class FederationRateLimiter:
                 # Handle request ...
 
         Args:
-            host (str): Origin of incoming request.
+            host: Origin of incoming request.
 
         Returns:
             context manager which returns a deferred.
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index d0a69ff843..dcc037b982 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -51,7 +51,7 @@ class NotRetryingDestination(Exception):
             destination: the domain in question
         """
 
-        msg = "Not retrying server %s." % (destination,)
+        msg = f"Not retrying server {destination} because we tried it recently retry_last_ts={retry_last_ts} and we won't check for another retry_interval={retry_interval}ms."
         super().__init__(msg)
 
         self.retry_last_ts = retry_last_ts
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 27a363d7e5..4961fe9313 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -86,7 +86,7 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
         ValueError if the server name could not be parsed.
     """
     try:
-        if server_name[-1] == "]":
+        if server_name and server_name[-1] == "]":
             # ipv6 literal, hopefully
             return server_name, None
 
@@ -123,7 +123,7 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]
     # that nobody is sneaking IP literals in that look like hostnames, etc.
 
     # look for ipv6 literals
-    if host[0] == "[":
+    if host and host[0] == "[":
         if host[-1] != "]":
             raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
 
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 1e9c2faa64..54bc7589fd 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -48,7 +48,7 @@ async def check_3pid_allowed(
         registration: whether we want to bind the 3PID as part of registering a new user.
 
     Returns:
-        bool: whether the 3PID medium/address is allowed to be added to this HS
+        whether the 3PID medium/address is allowed to be added to this HS
     """
     if not await hs.get_password_auth_provider().is_3pid_allowed(
         medium, address, registration
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 177e198e7e..b1ec7f4bd8 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -90,10 +90,10 @@ class WheelTimer(Generic[T]):
         """Fetch any objects that have timed out
 
         Args:
-            now (ms): Current time in msec
+            now: Current time in msec
 
         Returns:
-            list: List of objects that have timed out
+            List of objects that have timed out
         """
         now_key = int(now / self.bucket_size)