summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-04-20 19:11:58 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-04-20 19:13:11 +0100
commit99eb246b8ab5e9e2acce2e6eda8f2e82cd28e92d (patch)
tree0732f77cc165bb2a251f84b4a768eaab4bafc99b /synapse/handlers
parentDon't verify signatures using re-written id_server url (diff)
downloadsynapse-99eb246b8ab5e9e2acce2e6eda8f2e82cd28e92d.tar.xz
Consolidate id_server URL translation
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/identity.py61
-rw-r--r--synapse/handlers/room_member.py1
2 files changed, 27 insertions, 35 deletions
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index c48bbbbd93..6c369b78aa 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -104,8 +104,7 @@ class IdentityHandler(BaseHandler):
 
         # if we have a rewrite rule set for the identity server,
         # apply it now.
-        if id_server in self.rewrite_identity_server_urls:
-            id_server = self.rewrite_identity_server_urls[id_server]
+        id_server = self.rewrite_id_server_url(id_server)
 
         url = "https://%s%s" % (
             id_server,
@@ -168,10 +167,7 @@ class IdentityHandler(BaseHandler):
         # if we have a rewrite rule set for the identity server,
         # apply it now, but only for sending the request (not
         # storing in the database).
-        if id_server in self.rewrite_identity_server_urls:
-            id_server_host = self.rewrite_identity_server_urls[id_server]
-        else:
-            id_server_host = id_server
+        id_server_host = self.rewrite_id_server_url(id_server)
 
         # Decide which API endpoint URLs to use
         headers = {}
@@ -294,8 +290,7 @@ class IdentityHandler(BaseHandler):
         #
         # Note that destination_is has to be the real id_server, not
         # the server we connect to.
-        if id_server in self.rewrite_identity_server_urls:
-            id_server = self.rewrite_identity_server_urls[id_server]
+        id_server = self.rewrite_id_server_url(id_server)
 
         url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
 
@@ -413,6 +408,23 @@ class IdentityHandler(BaseHandler):
 
         return session_id
 
+    def rewrite_id_server_url(self, url: str) -> str:
+        """Given an identity server URL, rewrite it according to the
+        rewrite_identity_server_urls config option
+
+        First removes the protocol scheme from the URL if provided.
+        Then checks for a rewritten URL. If found, returns the new URL.
+        Otherwise, returns the original URL.
+        """
+        rewritten_url = url
+        if url.startswith("http://"):
+            rewritten_url = url[7:]
+        elif url.startswith("https://"):
+            rewritten_url = url[8:]
+
+        return self.rewrite_identity_server_urls.get(rewritten_url, url)
+
+
     @defer.inlineCallbacks
     def requestEmailToken(
         self, id_server, email, client_secret, send_attempt, next_link=None
@@ -439,8 +451,7 @@ class IdentityHandler(BaseHandler):
 
         # if we have a rewrite rule set for the identity server,
         # apply it now.
-        if id_server in self.rewrite_identity_server_urls:
-            id_server = self.rewrite_identity_server_urls[id_server]
+        id_server = self.rewrite_id_server_url(id_server)
 
         if next_link:
             params["next_link"] = next_link
@@ -510,8 +521,7 @@ class IdentityHandler(BaseHandler):
 
         # if we have a rewrite rule set for the identity server,
         # apply it now.
-        if id_server in self.rewrite_identity_server_urls:
-            id_server = self.rewrite_identity_server_urls[id_server]
+        id_server = self.rewrite_id_server_url(id_server)
         try:
             data = yield self.http_client.post_json_get_json(
                 id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
@@ -636,7 +646,7 @@ class IdentityHandler(BaseHandler):
                 403, "Looking up third-party identifiers is denied from this server"
             )
 
-        target = self.rewrite_identity_server_urls.get(id_server, id_server)
+        target = self.rewrite_id_server_url(id_server)
 
         try:
             data = yield self.http_client.get_json(
@@ -678,7 +688,7 @@ class IdentityHandler(BaseHandler):
                 403, "Looking up third-party identifiers is denied from this server"
             )
 
-        target = self.rewrite_identity_server_urls.get(id_server, id_server)
+        target = self.rewrite_id_server_url(id_server)
 
         try:
             data = yield self.http_client.post_json_get_json(
@@ -711,7 +721,7 @@ class IdentityHandler(BaseHandler):
             str|None: the matrix ID of the 3pid, or None if it is not recognized.
         """
         # Rewrite id_server URL if necessary
-        id_server_url = self._get_id_server_target(id_server)
+        id_server_url = self.rewrite_id_server_url(id_server)
 
         if id_access_token is not None:
             try:
@@ -886,9 +896,7 @@ class IdentityHandler(BaseHandler):
             raise AuthError(401, "No signature from server %s" % (server_hostname,))
 
         for key_name, signature in data["signatures"][server_hostname].items():
-            target = self.rewrite_identity_server_urls.get(
-                server_hostname, server_hostname
-            )
+            target = self.rewrite_id_server_url(server_hostname)
 
             key_data = yield self.http_client.get_json(
                 "https://%s/_matrix/identity/api/v1/pubkey/%s" % (target, key_name)
@@ -908,21 +916,6 @@ class IdentityHandler(BaseHandler):
 
         raise AuthError(401, "No signature from server %s" % (server_hostname,))
 
-    def _get_id_server_target(self, id_server):
-        """Looks up an id_server's actual http endpoint
-
-        Args:
-            id_server (str): the server name to lookup.
-
-        Returns:
-            the http endpoint to connect to.
-        """
-        if id_server in self.rewrite_identity_server_urls:
-            return self.rewrite_identity_server_urls[id_server]
-
-        return id_server
-
-
     @defer.inlineCallbacks
     def ask_id_server_for_third_party_invite(
         self,
@@ -984,7 +977,7 @@ class IdentityHandler(BaseHandler):
         }
 
         # Rewrite the identity server URL if necessary
-        id_server = self._get_id_server_target(id_server)
+        id_server = self.rewrite_id_server_url(id_server)
 
         # Add the identity service access token to the JSON body and use the v2
         # Identity Service endpoints if id_access_token is present
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index decef944ff..2d9ff97324 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -75,7 +75,6 @@ class RoomMemberHandler(object):
         self.spam_checker = hs.get_spam_checker()
         self.third_party_event_rules = hs.get_third_party_event_rules()
         self._server_notices_mxid = self.config.server_notices_mxid
-        self.rewrite_identity_server_urls = self.config.rewrite_identity_server_urls
         self._enable_lookup = hs.config.enable_3pid_lookup
         self.allow_per_room_profiles = self.config.allow_per_room_profiles
         self.ratelimiter = Ratelimiter()