summary refs log tree commit diff
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-03-24 11:57:11 +0000
committerBrendan Abolivier <babolivier@matrix.org>2022-03-24 11:57:11 +0000
commitaf4da7e51931ef74ad7a63b8b81e4f9f95049230 (patch)
treee1d52d939484c1f9938bb3f724d3b9d16468ed74
parentAdd type hints to tests files. (#12256) (diff)
downloadsynapse-af4da7e51931ef74ad7a63b8b81e4f9f95049230.tar.xz
Add a configuration option for rewriting base URLs when interacting with ISs
-rw-r--r--docs/sample_config.yaml17
-rw-r--r--synapse/config/server.py25
-rw-r--r--synapse/handlers/identity.py64
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--tests/rest/client/test_identity.py100
5 files changed, 182 insertions, 26 deletions
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 36c6c56e58..9336c557b2 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -539,6 +539,23 @@ templates:
   #
   #custom_template_directory: /path/to/custom/templates/
 
+# Base URLs to substitute when making requests to identity servers from Synapse.
+# This can be useful if an identity server exists under a different name or
+# address within an internal network than on the Internet.
+#
+# The first half of each line is the domain or address on which the identity
+# server is publicly accessible (without a protocol scheme) and the second half
+# is the base URL (i.e. protocol scheme and domain or address) to use for this
+# identity server.
+#
+# This list does not need to be exhaustive: if Synapse needs to send a request to
+# an identity server that isn't in this list it will just use its public name or
+# address.
+#
+#rewrite_identity_server_base_urls:
+#   public.example.com: http://public.int.example.com
+#   vip.example.com: http://vip.int.example.com
+
 
 # Message retention policy at the server level.
 #
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 49cd0a4f19..ef853c12be 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -676,6 +676,14 @@ class ServerConfig(Config):
         ):
             raise ConfigError("'custom_template_directory' must be a string")
 
+        self.identity_server_rewrite_map: Dict[str, str] = (
+            config.get("rewrite_identity_server_base_urls") or {}
+        )
+        if not isinstance(self.identity_server_rewrite_map, dict):
+            raise ConfigError(
+                "'rewrite_identity_server_base_urls' must be a dictionary"
+            )
+
     def has_tls_listener(self) -> bool:
         return any(listener.tls for listener in self.listeners)
 
@@ -1230,6 +1238,23 @@ class ServerConfig(Config):
           # information about using custom templates.
           #
           #custom_template_directory: /path/to/custom/templates/
+
+        # Base URLs to substitute when making requests to identity servers from Synapse.
+        # This can be useful if an identity server exists under a different name or
+        # address within an internal network than on the Internet.
+        #
+        # The first half of each line is the domain or address on which the identity
+        # server is publicly accessible (without a protocol scheme) and the second half
+        # is the base URL (i.e. protocol scheme and domain or address) to use for this
+        # identity server.
+        #
+        # This list does not need to be exhaustive: if Synapse needs to send a request to
+        # an identity server that isn't in this list it will just use its public name or
+        # address.
+        #
+        #rewrite_identity_server_base_urls:
+        #   public.example.com: http://public.int.example.com
+        #   vip.example.com: http://vip.int.example.com
         """
             % locals()
         )
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 57c9fdfe62..90cf1d5c79 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -62,6 +62,7 @@ class IdentityHandler:
         self.hs = hs
 
         self._web_client_location = hs.config.email.invite_client_location
+        self._identity_server_rewrite_map = hs.config.server.identity_server_rewrite_map
 
         # Ratelimiters for `/requestToken` endpoints.
         self._3pid_validation_ratelimiter_ip = Ratelimiter(
@@ -131,6 +132,7 @@ class IdentityHandler:
 
         query_params = {"sid": session_id, "client_secret": client_secret}
 
+        id_server = self._rewrite_is_url(id_server)
         url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
 
         try:
@@ -200,11 +202,12 @@ class IdentityHandler:
         # Decide which API endpoint URLs to use
         headers = {}
         bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
+        base_url = self._rewrite_is_url(id_server)
         if use_v2:
-            bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+            bind_url = "%s/_matrix/identity/v2/3pid/bind" % (base_url,)
             headers["Authorization"] = create_id_access_token_header(id_access_token)  # type: ignore
         else:
-            bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+            bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (base_url,)
 
         try:
             # Use the blacklisting http client as this call is only to identity servers
@@ -300,7 +303,8 @@ class IdentityHandler:
                 "id_server must be a valid hostname with optional port and path components",
             )
 
-        url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+        base_url = self._rewrite_is_url(id_server)
+        url = "%s/_matrix/identity/api/v1/3pid/unbind" % (base_url,)
         url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
 
         content = {
@@ -464,6 +468,8 @@ class IdentityHandler:
         if next_link:
             params["next_link"] = next_link
 
+        id_server = self._rewrite_is_url(id_server)
+
         try:
             data = await self.http_client.post_json_get_json(
                 id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
@@ -508,6 +514,8 @@ class IdentityHandler:
         if next_link:
             params["next_link"] = next_link
 
+        id_server = self._rewrite_is_url(id_server)
+
         try:
             data = await self.http_client.post_json_get_json(
                 id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
@@ -598,6 +606,8 @@ class IdentityHandler:
         """
         body = {"client_secret": client_secret, "sid": sid, "token": token}
 
+        id_server = self._rewrite_is_url(id_server)
+
         try:
             return await self.http_client.post_json_get_json(
                 id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
@@ -666,9 +676,10 @@ class IdentityHandler:
         Returns:
             the matrix ID of the 3pid, or None if it is not recognized.
         """
+        base_url = self._rewrite_is_url(id_server)
         try:
             data = await self.blacklisting_http_client.get_json(
-                "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+                "%s/_matrix/identity/api/v1/lookup" % (base_url,),
                 {"medium": medium, "address": address},
             )
 
@@ -700,9 +711,10 @@ class IdentityHandler:
             the matrix ID of the 3pid, or None if it is not recognised.
         """
         # Check what hashing details are supported by this identity server
+        base_url = self._rewrite_is_url(id_server)
         try:
             hash_details = await self.blacklisting_http_client.get_json(
-                "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
+                "%s/_matrix/identity/v2/hash_details" % (base_url,),
                 {"access_token": id_access_token},
             )
         except RequestTimedOutError:
@@ -769,7 +781,7 @@ class IdentityHandler:
 
         try:
             lookup_results = await self.blacklisting_http_client.post_json_get_json(
-                "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
+                "%s/_matrix/identity/v2/lookup" % (base_url,),
                 {
                     "addresses": [lookup_value],
                     "algorithm": lookup_algorithm,
@@ -868,13 +880,11 @@ class IdentityHandler:
         # Add the identity service access token to the JSON body and use the v2
         # Identity Service endpoints if id_access_token is present
         data = None
-        base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
+        base_is_url = self._rewrite_is_url(id_server)
+        base_url = "%s/_matrix/identity" % (base_is_url,)
 
         if id_access_token:
-            key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
-            )
+            key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % (base_is_url,)
 
             # Attempt a v2 lookup
             url = base_url + "/v2/store-invite"
@@ -892,9 +902,8 @@ class IdentityHandler:
                     raise e
 
         if data is None:
-            key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
+            key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+                base_is_url,
             )
             url = base_url + "/api/v1/store-invite"
 
@@ -946,6 +955,33 @@ class IdentityHandler:
         display_name = data["display_name"]
         return token, public_keys, fallback_public_key, display_name
 
+    def _rewrite_is_url(self, id_server: str) -> str:
+        """Replaces the base URL to an identity server using instructions from the config.
+
+        If no replacement is found for this URL, just returns the original URL with an
+        HTTPS protocol scheme appended to it if there isn't already one.
+
+        Args:
+            id_server: The identity server to optionally replace the base URL for. Might
+                include a protocol scheme.
+
+        Returns:
+            The base URL to use (with a protocol scheme). If no match can be found and
+            the provided identity server address already includes a protocol scheme, just
+            returns it as is. Otherwise, if no HTTP(S) protocol scheme can be found,
+            prepends an HTTPS protocol scheme to the address before returning it.
+        """
+        if id_server.startswith("https://"):
+            default_base_url = id_server
+            id_server = id_server.replace("https://", "")
+        elif id_server.startswith("http://"):
+            default_base_url = id_server
+            id_server = id_server.replace("https://", "")
+        else:
+            default_base_url = id_server_scheme + id_server
+
+        return self._identity_server_rewrite_map.get(id_server, default_base_url)
+
 
 def create_id_access_token_header(id_access_token: str) -> List[str]:
     """Create an Authorization header for passing to SimpleHttpClient as the header value
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 092e185c99..b88b46f0a4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -87,8 +87,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-id_server_scheme = "https://"
-
 FIVE_MINUTES_IN_MS = 5 * 60 * 1000
 
 
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 299b9d21e2..0c0c4e8498 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 from http import HTTPStatus
+from unittest.mock import Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
-from synapse.rest.client import login, room
+from synapse.http.client import SimpleHttpClient
+from synapse.rest.client import login, register, room
 from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class IdentityTestCase(unittest.HomeserverTestCase):
@@ -31,18 +33,55 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
         login.register_servlets,
+        register.register_servlets,
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
-        config["enable_3pid_lookup"] = False
         self.hs = self.setup_test_homeserver(config=config)
 
         return self.hs
 
+    @unittest.override_config({"enable_3pid_lookup": False})
     def test_3pid_lookup_disabled(self) -> None:
-        self.hs.config.registration.enable_3pid_lookup = False
+        self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+
+        channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        room_id = channel.json_body["room_id"]
+
+        self._send_threepid_invite(
+            id_server="vip.example.com",
+            address="test@example.com",
+            room_id=room_id,
+            tok=tok,
+            expected_status=HTTPStatus.FORBIDDEN,
+        )
+
+    @unittest.override_config(
+        {
+            "rewrite_identity_server_base_urls": {
+                "vip.example.com": "http://vip-int.example.com",
+            }
+        }
+    )
+    def test_rewrite_is_base_url(self) -> None:
+        """Tests that base URLs for identity services are correctly rewritten."""
+        mock_client = Mock(spec=SimpleHttpClient)
+        mock_client.post_json_get_json = Mock(
+            return_value=make_awaitable(
+                {
+                    "token": "sometoken",
+                    "public_key": "somekey",
+                    "public_keys": [],
+                    "display_name": "foo",
+                }
+            )
+        )
+
+        self.hs.get_identity_handler().blacklisting_http_client = mock_client
 
         self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
@@ -51,14 +90,55 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         room_id = channel.json_body["room_id"]
 
+        # Send a 3PID invite, and check that the base URL for the identity server has been
+        # correctly rewritten.
+        self._send_threepid_invite(
+            id_server="vip.example.com",
+            address="test@example.com",
+            room_id=room_id,
+            tok=tok,
+            expected_status=HTTPStatus.OK,
+        )
+
+        mock_client.post_json_get_json.assert_called_once()
+        args = mock_client.post_json_get_json.call_args[0]
+
+        self.assertTrue(args[0].startswith("http://vip-int.example.com"))
+
+        # Send another 3PID invite, this time to an identity server that doesn't need
+        # rewriting, and check that the base URL hasn't been rewritten (apart from adding
+        # an HTTPS protocol scheme).
+        self._send_threepid_invite(
+            id_server="testis",
+            address="test@example.com",
+            room_id=room_id,
+            tok=tok,
+            expected_status=HTTPStatus.OK,
+        )
+
+        self.assertEqual(mock_client.post_json_get_json.call_count, 2)
+        args = mock_client.post_json_get_json.call_args[0]
+
+        self.assertTrue(args[0].startswith("https://testis"))
+
+    def _send_threepid_invite(
+        self, id_server: str, address: str, room_id: str, tok: str, expected_status: int
+    ) -> None:
+        """Try to send a 3PID invite into a room.
+
+        Args:
+            id_server: the identity server to use to store the invite.
+            address: the email address to send the invite to.
+            room_id: the room the invite is for.
+            tok: the access token to authenticate the request with.
+            expected_status: the expected HTTP status in the response to /invite.
+        """
         params = {
-            "id_server": "testis",
+            "id_server": id_server,
             "medium": "email",
-            "address": "test@example.com",
+            "address": address,
         }
-        request_data = json.dumps(params)
-        request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
         channel = self.make_request(
-            b"POST", request_url, request_data, access_token=tok
+            b"POST", "/rooms/%s/invite" % (room_id,), params, access_token=tok
         )
-        self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
+        self.assertEqual(channel.code, expected_status, channel.result)