summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorAndrew Morgan <1342360+anoadragon453@users.noreply.github.com>2020-04-22 14:23:10 +0100
committerGitHub <noreply@github.com>2020-04-22 14:23:10 +0100
commite66bbf7a9db5cbd94ecd59b8375ab398d6951344 (patch)
tree8661d0190c6aca4bac41355176c1fb2742e9d4ae /synapse/handlers
parentRemove unnecessary shadow server code (diff)
downloadsynapse-e66bbf7a9db5cbd94ecd59b8375ab398d6951344.tar.xz
Fix and refactor rewritten IS url feature. Add sample config docs (#40)
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/identity.py155
1 files changed, 72 insertions, 83 deletions
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 6c369b78aa..e8a6cf7788 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -45,8 +45,6 @@ from ._base import BaseHandler
 
 logger = logging.getLogger(__name__)
 
-id_server_scheme = "https://"
-
 
 class IdentityHandler(BaseHandler):
     def __init__(self, hs):
@@ -69,13 +67,13 @@ class IdentityHandler(BaseHandler):
         self._enable_lookup = hs.config.enable_3pid_lookup
 
     @defer.inlineCallbacks
-    def threepid_from_creds(self, id_server, creds):
+    def threepid_from_creds(self, id_server_url, creds):
         """
         Retrieve and validate a threepid identifier from a "credentials" dictionary against a
         given identity server
 
         Args:
-            id_server (str): The identity server to validate 3PIDs against. Must be a
+            id_server_url (str): The identity server to validate 3PIDs against. Must be a
                 complete URL including the protocol (http(s)://)
 
             creds (dict[str, str]): Dictionary containing the following keys:
@@ -104,10 +102,10 @@ class IdentityHandler(BaseHandler):
 
         # if we have a rewrite rule set for the identity server,
         # apply it now.
-        id_server = self.rewrite_id_server_url(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server_url)
 
-        url = "https://%s%s" % (
-            id_server,
+        url = "%s%s" % (
+            id_server_url,
             "/_matrix/identity/api/v1/3pid/getValidated3pid",
         )
 
@@ -118,7 +116,7 @@ class IdentityHandler(BaseHandler):
         except HttpResponseException as e:
             logger.info(
                 "%s returned %i for threepid validation for: %s",
-                id_server,
+                id_server_url,
                 e.code,
                 creds,
             )
@@ -132,7 +130,7 @@ class IdentityHandler(BaseHandler):
         if "medium" in data:
             return data
 
-        logger.info("%s reported non-validated threepid: %s", id_server, creds)
+        logger.info("%s reported non-validated threepid: %s", id_server_url, creds)
         return None
 
     @defer.inlineCallbacks
@@ -167,18 +165,18 @@ class IdentityHandler(BaseHandler):
         # if we have a rewrite rule set for the identity server,
         # apply it now, but only for sending the request (not
         # storing in the database).
-        id_server_host = self.rewrite_id_server_url(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
 
         # Decide which API endpoint URLs to use
         headers = {}
         bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
         if use_v2:
-            bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server_host,)
+            bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,)
             headers["Authorization"] = create_id_access_token_header(
                 id_access_token
             )
         else:
-            bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server_host,)
+            bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,)
 
         try:
             # Use the blacklisting http client as this call is only to identity servers
@@ -265,9 +263,6 @@ class IdentityHandler(BaseHandler):
             Deferred[bool]: True on success, otherwise False if the identity
             server doesn't support unbinding
         """
-        url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
-        url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
-
         content = {
             "mxid": mxid,
             "threepid": {"medium": threepid["medium"], "address": threepid["address"]},
@@ -276,6 +271,7 @@ class IdentityHandler(BaseHandler):
         # we abuse the federation http client to sign the request, but we have to send it
         # using the normal http client since we don't want the SRV lookup and want normal
         # 'browser-like' HTTPS.
+        url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
         auth_headers = self.federation_http_client.build_auth_headers(
             destination=None,
             method="POST",
@@ -290,9 +286,9 @@ class IdentityHandler(BaseHandler):
         #
         # Note that destination_is has to be the real id_server, not
         # the server we connect to.
-        id_server = self.rewrite_id_server_url(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
 
-        url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+        url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,)
 
         try:
             # Use the blacklisting http client as this call is only to identity servers
@@ -408,33 +404,33 @@ class IdentityHandler(BaseHandler):
 
         return session_id
 
-    def rewrite_id_server_url(self, url: str) -> str:
-        """Given an identity server URL, rewrite it according to the
-        rewrite_identity_server_urls config option
+    def rewrite_id_server_url(self, url: str, add_https=False) -> str:
+        """Given an identity server URL, optionally add a protocol scheme
+        before rewriting it according to the rewrite_identity_server_urls
+        config option
 
-        First removes the protocol scheme from the URL if provided.
-        Then checks for a rewritten URL. If found, returns the new URL.
-        Otherwise, returns the original URL.
+        Adds https:// to the URL if specified, then tries to rewrite the
+        url. Returns either the rewritten URL or the URL with optional
+        protocol scheme additions.
         """
         rewritten_url = url
-        if url.startswith("http://"):
-            rewritten_url = url[7:]
-        elif url.startswith("https://"):
-            rewritten_url = url[8:]
-
-        return self.rewrite_identity_server_urls.get(rewritten_url, url)
+        if add_https:
+            rewritten_url = "https://" + rewritten_url
 
+        rewritten_url = self.rewrite_identity_server_urls.get(rewritten_url, rewritten_url)
+        logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url)
+        return rewritten_url
 
     @defer.inlineCallbacks
     def requestEmailToken(
-        self, id_server, email, client_secret, send_attempt, next_link=None
+        self, id_server_url, email, client_secret, send_attempt, next_link=None
     ):
         """
         Request an external server send an email on our behalf for the purposes of threepid
         validation.
 
         Args:
-            id_server (str): The identity server to proxy to
+            id_server_url (str): The identity server to proxy to
             email (str): The email to send the message to
             client_secret (str): The unique client_secret sends by the user
             send_attempt (int): Which attempt this is
@@ -451,7 +447,7 @@ class IdentityHandler(BaseHandler):
 
         # if we have a rewrite rule set for the identity server,
         # apply it now.
-        id_server = self.rewrite_id_server_url(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server_url)
 
         if next_link:
             params["next_link"] = next_link
@@ -467,7 +463,7 @@ class IdentityHandler(BaseHandler):
 
         try:
             data = yield self.http_client.post_json_get_json(
-                id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
+                "%s/_matrix/identity/api/v1/validate/email/requestToken" % (id_server_url,),
                 params,
             )
             return data
@@ -480,7 +476,7 @@ class IdentityHandler(BaseHandler):
     @defer.inlineCallbacks
     def requestMsisdnToken(
         self,
-        id_server,
+        id_server_url,
         country,
         phone_number,
         client_secret,
@@ -491,7 +487,7 @@ class IdentityHandler(BaseHandler):
         Request an external server send an SMS message on our behalf for the purposes of
         threepid validation.
         Args:
-            id_server (str): The identity server to proxy to
+            id_server_url (str): The identity server to proxy to
             country (str): The country code of the phone number
             phone_number (str): The number to send the message to
             client_secret (str): The unique client_secret sends by the user
@@ -521,10 +517,10 @@ class IdentityHandler(BaseHandler):
 
         # if we have a rewrite rule set for the identity server,
         # apply it now.
-        id_server = self.rewrite_id_server_url(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server_url)
         try:
             data = yield self.http_client.post_json_get_json(
-                id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
+                "%s/_matrix/identity/api/v1/validate/msisdn/requestToken" % (id_server_url,),
                 params,
             )
         except HttpResponseException as e:
@@ -646,11 +642,11 @@ class IdentityHandler(BaseHandler):
                 403, "Looking up third-party identifiers is denied from this server"
             )
 
-        target = self.rewrite_id_server_url(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
 
         try:
             data = yield self.http_client.get_json(
-                "https://%s/_matrix/identity/api/v1/lookup" % (target,),
+                "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
                 {"medium": medium, "address": address},
             )
 
@@ -663,7 +659,7 @@ class IdentityHandler(BaseHandler):
             logger.info("Proxied lookup failed: %r", e)
             raise e.to_synapse_error()
         except IOError as e:
-            logger.info("Failed to contact %r: %s", id_server, e)
+            logger.info("Failed to contact %s: %s", id_server, e)
             raise ProxiedRequestError(503, "Failed to contact identity server")
 
         defer.returnValue(data)
@@ -688,11 +684,11 @@ class IdentityHandler(BaseHandler):
                 403, "Looking up third-party identifiers is denied from this server"
             )
 
-        target = self.rewrite_id_server_url(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
 
         try:
             data = yield self.http_client.post_json_get_json(
-                "https://%s/_matrix/identity/api/v1/bulk_lookup" % (target,),
+                "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
                 {"threepids": threepids},
             )
 
@@ -700,7 +696,7 @@ class IdentityHandler(BaseHandler):
             logger.info("Proxied lookup failed: %r", e)
             raise e.to_synapse_error()
         except IOError as e:
-            logger.info("Failed to contact %r: %s", id_server, e)
+            logger.info("Failed to contact %s: %s", id_server, e)
             raise ProxiedRequestError(503, "Failed to contact identity server")
 
         defer.returnValue(data)
@@ -721,12 +717,12 @@ class IdentityHandler(BaseHandler):
             str|None: the matrix ID of the 3pid, or None if it is not recognized.
         """
         # Rewrite id_server URL if necessary
-        id_server_url = self.rewrite_id_server_url(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
 
         if id_access_token is not None:
             try:
                 results = yield self._lookup_3pid_v2(
-                    id_server, id_server_url, id_access_token, medium, address
+                    id_server_url, id_access_token, medium, address
                 )
                 return results
 
@@ -762,7 +758,7 @@ class IdentityHandler(BaseHandler):
         """
         try:
             data = yield self.http_client.get_json(
-                "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server_url),
+                "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
                 {"medium": medium, "address": address},
             )
 
@@ -779,13 +775,11 @@ class IdentityHandler(BaseHandler):
         return None
 
     @defer.inlineCallbacks
-    def _lookup_3pid_v2(self, id_server, id_server_url, id_access_token, medium, address):
+    def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address):
         """Looks up a 3pid in the passed identity server using v2 lookup.
 
         Args:
-            id_server (str): The server name (including port, if required)
-                of the identity server to use.
-            id_server_url (str): The actual, reachable domain of the id server
+            id_server_url (str): The protocol scheme and domain of the id server
             id_access_token (str): The access token to authenticate to the identity server with
             medium (str): The type of the third party identifier (e.g. "email").
             address (str): The third party identifier (e.g. "foo@example.com").
@@ -796,7 +790,7 @@ class IdentityHandler(BaseHandler):
         # Check what hashing details are supported by this identity server
         try:
             hash_details = yield self.http_client.get_json(
-                "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server_url),
+                "%s/_matrix/identity/v2/hash_details" % (id_server_url,),
                 {"access_token": id_access_token},
             )
         except TimeoutError:
@@ -804,15 +798,14 @@ class IdentityHandler(BaseHandler):
 
         if not isinstance(hash_details, dict):
             logger.warning(
-                "Got non-dict object when checking hash details of %s%s: %s",
-                id_server_scheme,
-                id_server,
+                "Got non-dict object when checking hash details of %s: %s",
+                id_server_url,
                 hash_details,
             )
             raise SynapseError(
                 400,
-                "Non-dict object from %s%s during v2 hash_details request: %s"
-                % (id_server_scheme, id_server, hash_details),
+                "Non-dict object from %s during v2 hash_details request: %s"
+                % (id_server_url, hash_details),
                 )
 
         # Extract information from hash_details
@@ -826,8 +819,8 @@ class IdentityHandler(BaseHandler):
         ):
             raise SynapseError(
                 400,
-                "Invalid hash details received from identity server %s%s: %s"
-                % (id_server_scheme, id_server, hash_details),
+                "Invalid hash details received from identity server %s: %s"
+                % (id_server_url, hash_details),
                 )
 
         # Check if any of the supported lookup algorithms are present
@@ -849,7 +842,7 @@ class IdentityHandler(BaseHandler):
         else:
             logger.warning(
                 "None of the provided lookup algorithms of %s are supported: %s",
-                id_server,
+                id_server_url,
                 supported_lookup_algorithms,
             )
             raise SynapseError(
@@ -863,7 +856,7 @@ class IdentityHandler(BaseHandler):
 
         try:
             lookup_results = yield self.http_client.post_json_get_json(
-                "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server_url),
+                "%s/_matrix/identity/v2/lookup" % (id_server_url,),
                 {
                     "addresses": [lookup_value],
                     "algorithm": lookup_algorithm,
@@ -891,30 +884,30 @@ class IdentityHandler(BaseHandler):
         return mxid
 
     @defer.inlineCallbacks
-    def _verify_any_signature(self, data, server_hostname):
-        if server_hostname not in data["signatures"]:
-            raise AuthError(401, "No signature from server %s" % (server_hostname,))
+    def _verify_any_signature(self, data, id_server):
+        if id_server not in data["signatures"]:
+            raise AuthError(401, "No signature from server %s" % (id_server,))
 
-        for key_name, signature in data["signatures"][server_hostname].items():
-            target = self.rewrite_id_server_url(server_hostname)
+        for key_name, signature in data["signatures"][id_server].items():
+            id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
 
             key_data = yield self.http_client.get_json(
-                "https://%s/_matrix/identity/api/v1/pubkey/%s" % (target, key_name)
+                "%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_url, key_name)
             )
             if "public_key" not in key_data:
                 raise AuthError(
-                    401, "No public key named %s from %s" % (key_name, server_hostname)
+                    401, "No public key named %s from %s" % (key_name, id_server)
                 )
             verify_signed_json(
                 data,
-                server_hostname,
+                id_server,
                 decode_verify_key_bytes(
                     key_name, decode_base64(key_data["public_key"])
                 ),
             )
             return
 
-        raise AuthError(401, "No signature from server %s" % (server_hostname,))
+        raise AuthError(401, "No signature from server %s" % (id_server,))
 
     @defer.inlineCallbacks
     def ask_id_server_for_third_party_invite(
@@ -977,17 +970,16 @@ class IdentityHandler(BaseHandler):
         }
 
         # Rewrite the identity server URL if necessary
-        id_server = self.rewrite_id_server_url(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
 
         # Add the identity service access token to the JSON body and use the v2
         # Identity Service endpoints if id_access_token is present
         data = None
-        base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
+        base_url = "%s/_matrix/identity" % (id_server_url,)
 
         if id_access_token:
-            key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
+            key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % (
+                id_server_url,
             )
 
             # Attempt a v2 lookup
@@ -1006,9 +998,8 @@ class IdentityHandler(BaseHandler):
                     raise e
 
         if data is None:
-            key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
+            key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+                id_server_url,
             )
             url = base_url + "/api/v1/store-invite"
 
@@ -1020,9 +1011,8 @@ class IdentityHandler(BaseHandler):
                 raise SynapseError(500, "Timed out contacting identity server")
             except HttpResponseException as e:
                 logger.warning(
-                    "Error trying to call /store-invite on %s%s: %s",
-                    id_server_scheme,
-                    id_server,
+                    "Error trying to call /store-invite on %s: %s",
+                    id_server_url,
                     e,
                 )
 
@@ -1036,10 +1026,9 @@ class IdentityHandler(BaseHandler):
                     )
                 except HttpResponseException as e:
                     logger.warning(
-                        "Error calling /store-invite on %s%s with fallback "
+                        "Error calling /store-invite on %s with fallback "
                         "encoding: %s",
-                        id_server_scheme,
-                        id_server,
+                        id_server_url,
                         e,
                     )
                     raise e