summary refs log tree commit diff
path: root/tests/handlers/test_oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_oidc.py')
-rw-r--r--tests/handlers/test_oidc.py59
1 files changed, 54 insertions, 5 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..89ec5fcb31 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -75,7 +75,17 @@ COMMON_CONFIG = {
 COOKIE_NAME = b"oidc_session"
 COOKIE_PATH = "/_synapse/oidc"
 
-MockedMappingProvider = Mock(OidcMappingProvider)
+
+class TestMappingProvider(OidcMappingProvider):
+    @staticmethod
+    def parse_config(config):
+        return
+
+    def get_remote_user_id(self, userinfo):
+        return userinfo["sub"]
+
+    async def map_user_attributes(self, userinfo, token):
+        return {"localpart": userinfo["username"], "display_name": None}
 
 
 def simple_async_mock(return_value=None, raises=None):
@@ -123,7 +133,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         oidc_config["issuer"] = ISSUER
         oidc_config["scopes"] = SCOPES
         oidc_config["user_mapping_provider"] = {
-            "module": __name__ + ".MockedMappingProvider"
+            "module": __name__ + ".TestMappingProvider",
         }
         config["oidc_config"] = oidc_config
 
@@ -374,12 +384,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
         self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
         self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(spec=["args", "getCookie", "addCookie"])
+        request = Mock(
+            spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+        )
 
         code = "code"
         state = "state"
         nonce = "nonce"
         client_redirect_url = "http://client/redirect"
+        user_agent = "Browser"
+        ip_address = "10.0.0.1"
         session = self.handler._generate_oidc_session_token(
             state=state,
             nonce=nonce,
@@ -392,6 +406,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         request.args[b"code"] = [code.encode("utf-8")]
         request.args[b"state"] = [state.encode("utf-8")]
 
+        request.requestHeaders = Mock(spec=["getRawHeaders"])
+        request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+        request.getClientIP.return_value = ip_address
+
         yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
 
         self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +417,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+        self.handler._map_userinfo_to_user.assert_called_once_with(
+            userinfo, token, user_agent, ip_address
+        )
         self.handler._fetch_userinfo.assert_not_called()
         self.handler._render_error.assert_not_called()
 
@@ -431,7 +451,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_not_called()
-        self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+        self.handler._map_userinfo_to_user.assert_called_once_with(
+            userinfo, token, user_agent, ip_address
+        )
         self.handler._fetch_userinfo.assert_called_once_with(token)
         self.handler._render_error.assert_not_called()
 
@@ -568,3 +590,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
         with self.assertRaises(OidcError) as exc:
             yield defer.ensureDeferred(self.handler._exchange_code(code))
         self.assertEqual(exc.exception.error, "some_error")
+
+    def test_map_userinfo_to_user(self):
+        """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+        userinfo = {
+            "sub": "test_user",
+            "username": "test_user",
+        }
+        # The token doesn't matter with the default user mapping provider.
+        token = {}
+        mxid = self.get_success(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user:test")
+
+        # Some providers return an integer ID.
+        userinfo = {
+            "sub": 1234,
+            "username": "test_user_2",
+        }
+        mxid = self.get_success(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user_2:test")