summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/iterutils.py53
-rw-r--r--synapse/util/stringutils.py111
2 files changed, 163 insertions, 1 deletions
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 06faeebe7f..6ef2b008a4 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -13,8 +13,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import heapq
 from itertools import islice
-from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+from typing import (
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    Mapping,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+)
+
+from synapse.types import Collection
 
 T = TypeVar("T")
 
@@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
     If the input is empty, no chunks are returned.
     """
     return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+
+
+def sorted_topologically(
+    nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+) -> Generator[T, None, None]:
+    """Given a set of nodes and a graph, yield the nodes in toplogical order.
+
+    For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
+    """
+
+    # This is implemented by Kahn's algorithm.
+
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph = {}  # type: Dict[T, Set[T]]
+
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+
+        for edge in edges:
+            if edge in degree_map:
+                degree_map[node] += 1
+
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+    heapq.heapify(zero_degree)
+
+    while zero_degree:
+        node = heapq.heappop(zero_degree)
+        yield node
+
+        for edge in reverse_graph.get(node, []):
+            if edge in degree_map:
+                degree_map[edge] -= 1
+                if degree_map[edge] == 0:
+                    heapq.heappush(zero_degree, edge)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 61d96a6c28..f8038bf861 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -18,6 +18,7 @@ import random
 import re
 import string
 from collections.abc import Iterable
+from typing import Optional, Tuple
 
 from synapse.api.errors import Codes, SynapseError
 
@@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
 client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
 
+# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
+# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
+# says "there is no grammar for media ids"
+#
+# The server_name part of this is purposely lax: use parse_and_validate_mxc for
+# additional validation.
+#
+MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
+
 # random_string and random_string_with_symbols are used for a range of things,
 # some cryptographically important, some less so. We use SystemRandom to make sure
 # we get cryptographically-secure randoms.
@@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret):
         )
 
 
+def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+    """Split a server name into host/port parts.
+
+    Args:
+        server_name: server name to parse
+
+    Returns:
+        host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    try:
+        if server_name[-1] == "]":
+            # ipv6 literal, hopefully
+            return server_name, None
+
+        domain_port = server_name.rsplit(":", 1)
+        domain = domain_port[0]
+        port = int(domain_port[1]) if domain_port[1:] else None
+        return domain, port
+    except Exception:
+        raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
+
+
+def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+    """Split a server name into host/port parts and do some basic validation.
+
+    Args:
+        server_name: server name to parse
+
+    Returns:
+        host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    host, port = parse_server_name(server_name)
+
+    # these tests don't need to be bulletproof as we'll find out soon enough
+    # if somebody is giving us invalid data. What we *do* need is to be sure
+    # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+    # look for ipv6 literals
+    if host[0] == "[":
+        if host[-1] != "]":
+            raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
+        return host, port
+
+    # otherwise it should only be alphanumerics.
+    if not VALID_HOST_REGEX.match(host):
+        raise ValueError(
+            "Server name '%s' contains invalid characters" % (server_name,)
+        )
+
+    return host, port
+
+
+def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
+    """Parse the given string as an MXC URI
+
+    Checks that the "server name" part is a valid server name
+
+    Args:
+        mxc: the (alleged) MXC URI to be checked
+    Returns:
+        hostname, port, media id
+    Raises:
+        ValueError if the URI cannot be parsed
+    """
+    m = MXC_REGEX.match(mxc)
+    if not m:
+        raise ValueError("mxc URI %r did not match expected format" % (mxc,))
+    server_name = m.group(1)
+    media_id = m.group(2)
+    host, port = parse_and_validate_server_name(server_name)
+    return host, port, media_id
+
+
 def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
     """If iterable has maxitems or fewer, return the stringification of a list
     containing those items.
@@ -75,3 +167,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
     if len(items) <= maxitems:
         return str(items)
     return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
+
+
+def strtobool(val: str) -> bool:
+    """Convert a string representation of truth to True or False
+
+    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+    are 'n', 'no', 'f', 'false', 'off', and '0'.  Raises ValueError if
+    'val' is anything else.
+
+    This is lifted from distutils.util.strtobool, with the exception that it actually
+    returns a bool, rather than an int.
+    """
+    val = val.lower()
+    if val in ("y", "yes", "t", "true", "on", "1"):
+        return True
+    elif val in ("n", "no", "f", "false", "off", "0"):
+        return False
+    else:
+        raise ValueError("invalid truth value %r" % (val,))