diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 2529845c9e..25ea1bcc91 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -110,7 +110,7 @@ class ResponseCache(Generic[T]):
return result.observe()
def wrap(
- self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
+ self, key: T, callback: Callable[..., Any], *args: Any, **kwargs: Any
) -> defer.Deferred:
"""Wrap together a *get* and *set* call, taking care of logcontexts
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 0469e7d120..e81e468899 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -14,11 +14,10 @@
import logging
import math
-from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
+from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
from sortedcontainers import SortedDict
-from synapse.types import Collection
from synapse.util import caches
logger = logging.getLogger(__name__)
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 6f73b1d56d..abfdc29832 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -15,6 +15,7 @@
import heapq
from itertools import islice
from typing import (
+ Collection,
Dict,
Generator,
Iterable,
@@ -26,8 +27,6 @@ from typing import (
TypeVar,
)
-from synapse.types import Collection
-
T = TypeVar("T")
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index c0e6fb9a60..cd82777f80 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -132,6 +132,38 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]
return host, port
+def valid_id_server_location(id_server: str) -> bool:
+ """Check whether an identity server location, such as the one passed as the
+ `id_server` parameter to `/_matrix/client/r0/account/3pid/bind`, is valid.
+
+ A valid identity server location consists of a valid hostname and optional
+ port number, optionally followed by any number of `/` delimited path
+ components, without any fragment or query string parts.
+
+ Args:
+ id_server: identity server location string to validate
+
+ Returns:
+ True if valid, False otherwise.
+ """
+
+ components = id_server.split("/", 1)
+
+ host = components[0]
+
+ try:
+ parse_and_validate_server_name(host)
+ except ValueError:
+ return False
+
+ if len(components) < 2:
+ # no path
+ return True
+
+ path = components[1]
+ return "#" not in path and "?" not in path
+
+
def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
"""Parse the given string as an MXC URI
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 281c5be4fb..a1cf1960b0 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -18,6 +18,16 @@ import re
logger = logging.getLogger(__name__)
+# it's unclear what the maximum length of an email address is. RFC3696 (as corrected
+# by errata) says:
+# the upper limit on address lengths should normally be considered to be 254.
+#
+# In practice, mail servers appear to be more tolerant and allow 400 characters
+# or so. Let's allow 500, which should be plenty for everyone.
+#
+MAX_EMAIL_ADDRESS_LENGTH = 500
+
+
def check_3pid_allowed(hs, medium, address):
"""Checks whether a given format of 3PID is allowed to be used on this HS
@@ -70,3 +80,23 @@ def canonicalise_email(address: str) -> str:
raise ValueError("Unable to parse email address")
return parts[0].casefold() + "@" + parts[1].lower()
+
+
+def validate_email(address: str) -> str:
+ """Does some basic validation on an email address.
+
+ Returns the canonicalised email, as returned by `canonicalise_email`.
+
+ Raises a ValueError if the email is invalid.
+ """
+ # First we try canonicalising in case that fails
+ address = canonicalise_email(address)
+
+ # Email addresses have to be at least 3 characters.
+ if len(address) < 3:
+ raise ValueError("Unable to parse email address")
+
+ if len(address) > MAX_EMAIL_ADDRESS_LENGTH:
+ raise ValueError("Unable to parse email address")
+
+ return address
|