summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_oidc.py89
1 files changed, 69 insertions, 20 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 630e6da808..b4fa02acc4 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
 from urllib.parse import parse_qs, urlparse
 
@@ -24,12 +23,8 @@ import pymacaroons
 from twisted.python.failure import Failure
 from twisted.web._newclient import ResponseDone
 
-from synapse.handlers.oidc_handler import (
-    MappingException,
-    OidcError,
-    OidcHandler,
-    OidcMappingProvider,
-)
+from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider
+from synapse.handlers.sso import MappingException
 from synapse.types import UserID
 
 from tests.unittest import HomeserverTestCase, override_config
@@ -132,14 +127,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         config = self.default_config()
         config["public_baseurl"] = BASE_URL
-        oidc_config = {}
-        oidc_config["enabled"] = True
-        oidc_config["client_id"] = CLIENT_ID
-        oidc_config["client_secret"] = CLIENT_SECRET
-        oidc_config["issuer"] = ISSUER
-        oidc_config["scopes"] = SCOPES
-        oidc_config["user_mapping_provider"] = {
-            "module": __name__ + ".TestMappingProvider",
+        oidc_config = {
+            "enabled": True,
+            "client_id": CLIENT_ID,
+            "client_secret": CLIENT_SECRET,
+            "issuer": ISSUER,
+            "scopes": SCOPES,
+            "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
         }
 
         # Update this config with what's in the default config so that
@@ -705,13 +699,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
     def test_map_userinfo_to_existing_user(self):
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
         store = self.hs.get_datastore()
-        user4 = UserID.from_string("@test_user_4:test")
+        user = UserID.from_string("@test_user:test")
         self.get_success(
-            store.register_user(user_id=user4.to_string(), password_hash=None)
+            store.register_user(user_id=user.to_string(), password_hash=None)
         )
         userinfo = {
-            "sub": "test4",
-            "username": "test_user_4",
+            "sub": "test",
+            "username": "test_user",
         }
         token = {}
         mxid = self.get_success(
@@ -719,4 +713,59 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 userinfo, token, "user-agent", "10.10.10.10"
             )
         )
-        self.assertEqual(mxid, "@test_user_4:test")
+        self.assertEqual(mxid, "@test_user:test")
+
+        # Register some non-exact matching cases.
+        user2 = UserID.from_string("@TEST_user_2:test")
+        self.get_success(
+            store.register_user(user_id=user2.to_string(), password_hash=None)
+        )
+        user2_caps = UserID.from_string("@test_USER_2:test")
+        self.get_success(
+            store.register_user(user_id=user2_caps.to_string(), password_hash=None)
+        )
+
+        # Attempting to login without matching a name exactly is an error.
+        userinfo = {
+            "sub": "test2",
+            "username": "TEST_USER_2",
+        }
+        e = self.get_failure(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            ),
+            MappingException,
+        )
+        self.assertTrue(
+            str(e.value).startswith(
+                "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
+            )
+        )
+
+        # Logging in when matching a name exactly should work.
+        user2 = UserID.from_string("@TEST_USER_2:test")
+        self.get_success(
+            store.register_user(user_id=user2.to_string(), password_hash=None)
+        )
+
+        mxid = self.get_success(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@TEST_USER_2:test")
+
+    def test_map_userinfo_to_invalid_localpart(self):
+        """If the mapping provider generates an invalid localpart it should be rejected."""
+        userinfo = {
+            "sub": "test2",
+            "username": "föö",
+        }
+        token = {}
+        e = self.get_failure(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            ),
+            MappingException,
+        )
+        self.assertEqual(str(e.value), "localpart is invalid: föö")