diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3f0..9a873c8e8e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,10 +15,12 @@
# limitations under the License.
import collections
+import inspect
import logging
from contextlib import contextmanager
from typing import (
Any,
+ Awaitable,
Callable,
Dict,
Hashable,
@@ -542,11 +544,11 @@ class DoneAwaitable:
raise StopIteration(self.value)
-def maybe_awaitable(value):
+def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable.
"""
-
- if hasattr(value, "__await__"):
+ if inspect.isawaitable(value):
+ assert isinstance(value, Awaitable)
return value
return DoneAwaitable(value)
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 601305487c..1adc92eb90 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -105,7 +105,7 @@ class DeferredCache(Generic[KT, VT]):
keylen=keylen,
cache_name=name,
cache_type=cache_type,
- size_callback=(lambda d: len(d)) if iterable else None,
+ size_callback=(lambda d: len(d) or 1) if iterable else None,
metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config,
) # type: LruCache[KT, VT]
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e95393c..a6ee9edaec 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -105,10 +105,7 @@ class Signal:
async def do(observer):
try:
- result = observer(*args, **kwargs)
- if inspect.isawaitable(result):
- result = await result
- return result
+ return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
logger.warning(
"%s signal observer %s failed: %r", self.name, observer, e,
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 06faeebe7f..8d2411513f 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -13,8 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import heapq
from itertools import islice
-from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+from typing import (
+ Dict,
+ Generator,
+ Iterable,
+ Iterator,
+ Mapping,
+ Sequence,
+ Set,
+ Tuple,
+ TypeVar,
+)
+
+from synapse.types import Collection
T = TypeVar("T")
@@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
If the input is empty, no chunks are returned.
"""
return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+
+
+def sorted_topologically(
+ nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+) -> Generator[T, None, None]:
+ """Given a set of nodes and a graph, yield the nodes in toplogical order.
+
+ For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
+ """
+
+ # This is implemented by Kahn's algorithm.
+
+ degree_map = {node: 0 for node in nodes}
+ reverse_graph = {} # type: Dict[T, Set[T]]
+
+ for node, edges in graph.items():
+ if node not in degree_map:
+ continue
+
+ for edge in set(edges):
+ if edge in degree_map:
+ degree_map[node] += 1
+
+ reverse_graph.setdefault(edge, set()).add(node)
+ reverse_graph.setdefault(node, set())
+
+ zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+ heapq.heapify(zero_degree)
+
+ while zero_degree:
+ node = heapq.heappop(zero_degree)
+ yield node
+
+ for edge in reverse_graph.get(node, []):
+ if edge in degree_map:
+ degree_map[edge] -= 1
+ if degree_map[edge] == 0:
+ heapq.heappush(zero_degree, edge)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ffdea0de8d..f4de6b9f54 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -108,7 +108,16 @@ class Measure:
def __init__(self, clock, name):
self.clock = clock
self.name = name
- parent_context = current_context()
+ curr_context = current_context()
+ if not curr_context:
+ logger.warning(
+ "Starting metrics collection %r from sentinel context: metrics will be lost",
+ name,
+ )
+ parent_context = None
+ else:
+ assert isinstance(curr_context, LoggingContext)
+ parent_context = curr_context
self._logging_context = LoggingContext(
"Measure[%s]" % (self.name,), parent_context
)
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 94b59afb38..09b094ded7 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -15,28 +15,57 @@
import importlib
import importlib.util
+import itertools
+from typing import Any, Iterable, Tuple, Type
+
+import jsonschema
from synapse.config._base import ConfigError
+from synapse.config._util import json_error_to_config_error
-def load_module(provider):
+def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
""" Loads a synapse module with its config
- Take a dict with keys 'module' (the module name) and 'config'
- (the config dict).
+
+ Args:
+ provider: a dict with keys 'module' (the module name) and 'config'
+ (the config dict).
+ config_path: the path within the config file. This will be used as a basis
+ for any error message.
Returns
Tuple of (provider class, parsed config object)
"""
+
+ modulename = provider.get("module")
+ if not isinstance(modulename, str):
+ raise ConfigError(
+ "expected a string", path=itertools.chain(config_path, ("module",))
+ )
+
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
- module, clz = provider["module"].rsplit(".", 1)
+ module, clz = modulename.rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
+ # Load the module config. If None, pass an empty dictionary instead
+ module_config = provider.get("config") or {}
try:
- provider_config = provider_class.parse_config(provider.get("config"))
+ provider_config = provider_class.parse_config(module_config)
+ except jsonschema.ValidationError as e:
+ raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
+ except ConfigError as e:
+ raise _wrap_config_error(
+ "Failed to parse config for module %r" % (modulename,),
+ prefix=itertools.chain(config_path, ("config",)),
+ e=e,
+ )
except Exception as e:
- raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
+ raise ConfigError(
+ "Failed to parse config for module %r" % (modulename,),
+ path=itertools.chain(config_path, ("config",)),
+ ) from e
return provider_class, provider_config
@@ -56,3 +85,27 @@ def load_python_module(location: str):
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
return mod
+
+
+def _wrap_config_error(
+ msg: str, prefix: Iterable[str], e: ConfigError
+) -> "ConfigError":
+ """Wrap a relative ConfigError with a new path
+
+ This is useful when we have a ConfigError with a relative path due to a problem
+ parsing part of the config, and we now need to set it in context.
+ """
+ path = prefix
+ if e.path:
+ path = itertools.chain(prefix, e.path)
+
+ e1 = ConfigError(msg, path)
+
+ # ideally we would set the 'cause' of the new exception to the original exception;
+ # however now that we have merged the path into our own, the stringification of
+ # e will be incorrect, so instead we create a new exception with just the "msg"
+ # part.
+
+ e1.__cause__ = Exception(e.msg)
+ e1.__cause__.__cause__ = e.__cause__
+ return e1
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 61d96a6c28..f8038bf861 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -18,6 +18,7 @@ import random
import re
import string
from collections.abc import Iterable
+from typing import Optional, Tuple
from synapse.api.errors import Codes, SynapseError
@@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
+# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
+# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
+# says "there is no grammar for media ids"
+#
+# The server_name part of this is purposely lax: use parse_and_validate_mxc for
+# additional validation.
+#
+MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
+
# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
# we get cryptographically-secure randoms.
@@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret):
)
+def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+ """Split a server name into host/port parts.
+
+ Args:
+ server_name: server name to parse
+
+ Returns:
+ host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ try:
+ if server_name[-1] == "]":
+ # ipv6 literal, hopefully
+ return server_name, None
+
+ domain_port = server_name.rsplit(":", 1)
+ domain = domain_port[0]
+ port = int(domain_port[1]) if domain_port[1:] else None
+ return domain, port
+ except Exception:
+ raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
+
+
+def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+ """Split a server name into host/port parts and do some basic validation.
+
+ Args:
+ server_name: server name to parse
+
+ Returns:
+ host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ host, port = parse_server_name(server_name)
+
+ # these tests don't need to be bulletproof as we'll find out soon enough
+ # if somebody is giving us invalid data. What we *do* need is to be sure
+ # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+ # look for ipv6 literals
+ if host[0] == "[":
+ if host[-1] != "]":
+ raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
+ return host, port
+
+ # otherwise it should only be alphanumerics.
+ if not VALID_HOST_REGEX.match(host):
+ raise ValueError(
+ "Server name '%s' contains invalid characters" % (server_name,)
+ )
+
+ return host, port
+
+
+def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
+ """Parse the given string as an MXC URI
+
+ Checks that the "server name" part is a valid server name
+
+ Args:
+ mxc: the (alleged) MXC URI to be checked
+ Returns:
+ hostname, port, media id
+ Raises:
+ ValueError if the URI cannot be parsed
+ """
+ m = MXC_REGEX.match(mxc)
+ if not m:
+ raise ValueError("mxc URI %r did not match expected format" % (mxc,))
+ server_name = m.group(1)
+ media_id = m.group(2)
+ host, port = parse_and_validate_server_name(server_name)
+ return host, port, media_id
+
+
def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
"""If iterable has maxitems or fewer, return the stringification of a list
containing those items.
@@ -75,3 +167,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
if len(items) <= maxitems:
return str(items)
return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
+
+
+def strtobool(val: str) -> bool:
+ """Convert a string representation of truth to True or False
+
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+ are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
+ 'val' is anything else.
+
+ This is lifted from distutils.util.strtobool, with the exception that it actually
+ returns a bool, rather than an int.
+ """
+ val = val.lower()
+ if val in ("y", "yes", "t", "true", "on", "1"):
+ return True
+ elif val in ("n", "no", "f", "false", "off", "0"):
+ return False
+ else:
+ raise ValueError("invalid truth value %r" % (val,))
diff --git a/synapse/util/templates.py b/synapse/util/templates.py
new file mode 100644
index 0000000000..392dae4a40
--- /dev/null
+++ b/synapse/util/templates.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for dealing with jinja2 templates"""
+
+import time
+import urllib.parse
+from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
+
+import jinja2
+
+if TYPE_CHECKING:
+ from synapse.config.homeserver import HomeServerConfig
+
+
+def build_jinja_env(
+ template_search_directories: Iterable[str],
+ config: "HomeServerConfig",
+ autoescape: Union[bool, Callable[[str], bool], None] = None,
+) -> jinja2.Environment:
+ """Set up a Jinja2 environment to load templates from the given search path
+
+ The returned environment defines the following filters:
+ - format_ts: formats timestamps as strings in the server's local timezone
+ (XXX: why is that useful??)
+ - mxc_to_http: converts mxc: uris to http URIs. Args are:
+ (uri, width, height, resize_method="crop")
+
+ and the following global variables:
+ - server_name: matrix server name
+
+ Args:
+ template_search_directories: directories to search for templates
+
+ config: homeserver config, for things like `server_name` and `public_baseurl`
+
+ autoescape: whether template variables should be autoescaped. bool, or
+ a function mapping from template name to bool. Defaults to escaping templates
+ whose names end in .html, .xml or .htm.
+
+ Returns:
+ jinja environment
+ """
+
+ if autoescape is None:
+ autoescape = jinja2.select_autoescape()
+
+ loader = jinja2.FileSystemLoader(template_search_directories)
+ env = jinja2.Environment(loader=loader, autoescape=autoescape)
+
+ # Update the environment with our custom filters
+ env.filters.update(
+ {
+ "format_ts": _format_ts_filter,
+ "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl),
+ }
+ )
+
+ # common variables for all templates
+ env.globals.update({"server_name": config.server_name})
+
+ return env
+
+
+def _create_mxc_to_http_filter(
+ public_baseurl: Optional[str],
+) -> Callable[[str, int, int, str], str]:
+ """Create and return a jinja2 filter that converts MXC urls to HTTP
+
+ Args:
+ public_baseurl: The public, accessible base URL of the homeserver
+ """
+
+ def mxc_to_http_filter(
+ value: str, width: int, height: int, resize_method: str = "crop"
+ ) -> str:
+ if not public_baseurl:
+ raise RuntimeError(
+ "public_baseurl must be set in the homeserver config to convert MXC URLs to HTTP URLs."
+ )
+
+ if value[0:6] != "mxc://":
+ return ""
+
+ server_and_media_id = value[6:]
+ fragment = None
+ if "#" in server_and_media_id:
+ server_and_media_id, fragment = server_and_media_id.split("#", 1)
+ fragment = "#" + fragment
+
+ params = {"width": width, "height": height, "method": resize_method}
+ return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+ public_baseurl,
+ server_and_media_id,
+ urllib.parse.urlencode(params),
+ fragment or "",
+ )
+
+ return mxc_to_http_filter
+
+
+def _format_ts_filter(value: int, format: str):
+ return time.strftime(format, time.localtime(value / 1000))
|