diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3f0..9a873c8e8e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,10 +15,12 @@
# limitations under the License.
import collections
+import inspect
import logging
from contextlib import contextmanager
from typing import (
Any,
+ Awaitable,
Callable,
Dict,
Hashable,
@@ -542,11 +544,11 @@ class DoneAwaitable:
raise StopIteration(self.value)
-def maybe_awaitable(value):
+def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable.
"""
-
- if hasattr(value, "__await__"):
+ if inspect.isawaitable(value):
+ assert isinstance(value, Awaitable)
return value
return DoneAwaitable(value)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e95393c..a6ee9edaec 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -105,10 +105,7 @@ class Signal:
async def do(observer):
try:
- result = observer(*args, **kwargs)
- if inspect.isawaitable(result):
- result = await result
- return result
+ return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
logger.warning(
"%s signal observer %s failed: %r", self.name, observer, e,
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 94b59afb38..1ee61851e4 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -15,28 +15,56 @@
import importlib
import importlib.util
+import itertools
+from typing import Any, Iterable, Tuple, Type
+
+import jsonschema
from synapse.config._base import ConfigError
+from synapse.config._util import json_error_to_config_error
-def load_module(provider):
+def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
""" Loads a synapse module with its config
- Take a dict with keys 'module' (the module name) and 'config'
- (the config dict).
+
+ Args:
+ provider: a dict with keys 'module' (the module name) and 'config'
+ (the config dict).
+ config_path: the path within the config file. This will be used as a basis
+ for any error message.
Returns
Tuple of (provider class, parsed config object)
"""
+
+ modulename = provider.get("module")
+ if not isinstance(modulename, str):
+ raise ConfigError(
+ "expected a string", path=itertools.chain(config_path, ("module",))
+ )
+
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
- module, clz = provider["module"].rsplit(".", 1)
+ module, clz = modulename.rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
+ module_config = provider.get("config")
try:
- provider_config = provider_class.parse_config(provider.get("config"))
+ provider_config = provider_class.parse_config(module_config)
+ except jsonschema.ValidationError as e:
+ raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
+ except ConfigError as e:
+ raise _wrap_config_error(
+ "Failed to parse config for module %r" % (modulename,),
+ prefix=itertools.chain(config_path, ("config",)),
+ e=e,
+ )
except Exception as e:
- raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
+ raise ConfigError(
+ "Failed to parse config for module %r" % (modulename,),
+ path=itertools.chain(config_path, ("config",)),
+ ) from e
return provider_class, provider_config
@@ -56,3 +84,27 @@ def load_python_module(location: str):
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
return mod
+
+
+def _wrap_config_error(
+ msg: str, prefix: Iterable[str], e: ConfigError
+) -> "ConfigError":
+ """Wrap a relative ConfigError with a new path
+
+ This is useful when we have a ConfigError with a relative path due to a problem
+ parsing part of the config, and we now need to set it in context.
+ """
+ path = prefix
+ if e.path:
+ path = itertools.chain(prefix, e.path)
+
+ e1 = ConfigError(msg, path)
+
+ # ideally we would set the 'cause' of the new exception to the original exception;
+ # however now that we have merged the path into our own, the stringification of
+ # e will be incorrect, so instead we create a new exception with just the "msg"
+ # part.
+
+ e1.__cause__ = Exception(e.msg)
+ e1.__cause__.__cause__ = e.__cause__
+ return e1
|