summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/identity.py29
-rw-r--r--synapse/util/stringutils.py32
2 files changed, 58 insertions, 3 deletions
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 87a8b89237..0b3b1fadb5 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 
 """Utilities for interacting with Identity Servers"""
-
 import logging
 import urllib.parse
 from typing import Awaitable, Callable, Dict, List, Optional, Tuple
@@ -34,7 +33,11 @@ from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, Requester
 from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
-from synapse.util.stringutils import assert_valid_client_secret, random_string
+from synapse.util.stringutils import (
+    assert_valid_client_secret,
+    random_string,
+    valid_id_server_location,
+)
 
 from ._base import BaseHandler
 
@@ -172,6 +175,11 @@ class IdentityHandler(BaseHandler):
                 server with, if necessary. Required if use_v2 is true
             use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
 
+        Raises:
+            SynapseError: On any of the following conditions
+                - the supplied id_server is not a valid identity server name
+                - we failed to contact the supplied identity server
+
         Returns:
             The response from the identity server
         """
@@ -181,6 +189,12 @@ class IdentityHandler(BaseHandler):
         if id_access_token is None:
             use_v2 = False
 
+        if not valid_id_server_location(id_server):
+            raise SynapseError(
+                400,
+                "id_server must be a valid hostname with optional port and path components",
+            )
+
         # Decide which API endpoint URLs to use
         headers = {}
         bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
@@ -269,12 +283,21 @@ class IdentityHandler(BaseHandler):
             id_server: Identity server to unbind from
 
         Raises:
-            SynapseError: If we failed to contact the identity server
+            SynapseError: On any of the following conditions
+                - the supplied id_server is not a valid identity server name
+                - we failed to contact the supplied identity server
 
         Returns:
             True on success, otherwise False if the identity
             server doesn't support unbinding
         """
+
+        if not valid_id_server_location(id_server):
+            raise SynapseError(
+                400,
+                "id_server must be a valid hostname with optional port and path components",
+            )
+
         url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
         url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
 
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index c0e6fb9a60..cd82777f80 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -132,6 +132,38 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]
     return host, port
 
 
+def valid_id_server_location(id_server: str) -> bool:
+    """Check whether an identity server location, such as the one passed as the
+    `id_server` parameter to `/_matrix/client/r0/account/3pid/bind`, is valid.
+
+    A valid identity server location consists of a valid hostname and optional
+    port number, optionally followed by any number of `/` delimited path
+    components, without any fragment or query string parts.
+
+    Args:
+        id_server: identity server location string to validate
+
+    Returns:
+        True if valid, False otherwise.
+    """
+
+    components = id_server.split("/", 1)
+
+    host = components[0]
+
+    try:
+        parse_and_validate_server_name(host)
+    except ValueError:
+        return False
+
+    if len(components) < 2:
+        # no path
+        return True
+
+    path = components[1]
+    return "#" not in path and "?" not in path
+
+
 def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
     """Parse the given string as an MXC URI