diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 81df71a0c5..8514a75a1c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -220,7 +220,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
self.prune_unread_entries = prune_unread_entries
- def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
+ def __get__(
+ self, obj: Optional[Any], owner: Optional[Type]
+ ) -> Callable[..., "defer.Deferred[Any]"]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.name,
max_entries=self.max_entries,
@@ -232,7 +234,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
- def _wrapped(*args: Any, **kwargs: Any) -> Any:
+ def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 4938ddf703..a0efb96d3b 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -15,11 +15,13 @@
import heapq
from itertools import islice
from typing import (
+ Callable,
Collection,
Dict,
Generator,
Iterable,
Iterator,
+ List,
Mapping,
Set,
Sized,
@@ -71,6 +73,31 @@ def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]:
return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+def partition(
+ iterable: Iterable[T], predicate: Callable[[T], bool]
+) -> Tuple[List[T], List[T]]:
+ """
+ Separate a given iterable into two lists based on the result of a predicate function.
+
+ Args:
+ iterable: the iterable to partition (separate)
+ predicate: a function that takes an item from the iterable and returns a boolean
+
+ Returns:
+ A tuple of two lists, the first containing all items for which the predicate
+ returned True, the second containing all items for which the predicate returned
+ False
+ """
+ true_results = []
+ false_results = []
+ for item in iterable:
+ if predicate(item):
+ true_results.append(item)
+ else:
+ false_results.append(item)
+ return true_results, false_results
+
+
def sorted_topologically(
nodes: Iterable[T],
graph: Mapping[T, Collection[T]],
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 5a638c6e9a..e3a54df48b 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -14,17 +14,17 @@
import importlib
import importlib.util
-import itertools
from types import ModuleType
-from typing import Any, Iterable, Tuple, Type
+from typing import Any, Tuple, Type
import jsonschema
from synapse.config._base import ConfigError
from synapse.config._util import json_error_to_config_error
+from synapse.types import StrSequence
-def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
+def load_module(provider: dict, config_path: StrSequence) -> Tuple[Type, Any]:
"""Loads a synapse module with its config
Args:
@@ -39,9 +39,7 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
modulename = provider.get("module")
if not isinstance(modulename, str):
- raise ConfigError(
- "expected a string", path=itertools.chain(config_path, ("module",))
- )
+ raise ConfigError("expected a string", path=tuple(config_path) + ("module",))
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
@@ -55,19 +53,17 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
try:
provider_config = provider_class.parse_config(module_config)
except jsonschema.ValidationError as e:
- raise json_error_to_config_error(
- e, itertools.chain(config_path, ("config",))
- )
+ raise json_error_to_config_error(e, tuple(config_path) + ("config",))
except ConfigError as e:
raise _wrap_config_error(
"Failed to parse config for module %r" % (modulename,),
- prefix=itertools.chain(config_path, ("config",)),
+ prefix=tuple(config_path) + ("config",),
e=e,
)
except Exception as e:
raise ConfigError(
"Failed to parse config for module %r" % (modulename,),
- path=itertools.chain(config_path, ("config",)),
+ path=tuple(config_path) + ("config",),
) from e
else:
provider_config = module_config
@@ -92,9 +88,7 @@ def load_python_module(location: str) -> ModuleType:
return mod
-def _wrap_config_error(
- msg: str, prefix: Iterable[str], e: ConfigError
-) -> "ConfigError":
+def _wrap_config_error(msg: str, prefix: StrSequence, e: ConfigError) -> "ConfigError":
"""Wrap a relative ConfigError with a new path
This is useful when we have a ConfigError with a relative path due to a problem
@@ -102,7 +96,7 @@ def _wrap_config_error(
"""
path = prefix
if e.path:
- path = itertools.chain(prefix, e.path)
+ path = tuple(prefix) + tuple(e.path)
e1 = ConfigError(msg, path)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index f262bf95a0..2ad55ac13e 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -25,10 +25,12 @@ from typing import (
Iterator,
List,
Mapping,
+ MutableSet,
Optional,
Set,
Tuple,
)
+from weakref import WeakSet
from prometheus_client.core import Counter
from typing_extensions import ContextManager
@@ -86,7 +88,9 @@ queue_wait_timer = Histogram(
)
-_rate_limiter_instances: Set["FederationRateLimiter"] = set()
+# This must be a `WeakSet`, otherwise we indirectly hold on to entire `HomeServer`s
+# during trial test runs and leak a lot of memory.
+_rate_limiter_instances: MutableSet["FederationRateLimiter"] = WeakSet()
# Protects the _rate_limiter_instances set from concurrent access
_rate_limiter_instances_lock = threading.Lock()
|