diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 630e6da808..b4fa02acc4 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
from urllib.parse import parse_qs, urlparse
@@ -24,12 +23,8 @@ import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
-from synapse.handlers.oidc_handler import (
- MappingException,
- OidcError,
- OidcHandler,
- OidcMappingProvider,
-)
+from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider
+from synapse.handlers.sso import MappingException
from synapse.types import UserID
from tests.unittest import HomeserverTestCase, override_config
@@ -132,14 +127,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
config = self.default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = {}
- oidc_config["enabled"] = True
- oidc_config["client_id"] = CLIENT_ID
- oidc_config["client_secret"] = CLIENT_SECRET
- oidc_config["issuer"] = ISSUER
- oidc_config["scopes"] = SCOPES
- oidc_config["user_mapping_provider"] = {
- "module": __name__ + ".TestMappingProvider",
+ oidc_config = {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "issuer": ISSUER,
+ "scopes": SCOPES,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
@@ -705,13 +699,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
- user4 = UserID.from_string("@test_user_4:test")
+ user = UserID.from_string("@test_user:test")
self.get_success(
- store.register_user(user_id=user4.to_string(), password_hash=None)
+ store.register_user(user_id=user.to_string(), password_hash=None)
)
userinfo = {
- "sub": "test4",
- "username": "test_user_4",
+ "sub": "test",
+ "username": "test_user",
}
token = {}
mxid = self.get_success(
@@ -719,4 +713,59 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, "user-agent", "10.10.10.10"
)
)
- self.assertEqual(mxid, "@test_user_4:test")
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Register some non-exact matching cases.
+ user2 = UserID.from_string("@TEST_user_2:test")
+ self.get_success(
+ store.register_user(user_id=user2.to_string(), password_hash=None)
+ )
+ user2_caps = UserID.from_string("@test_USER_2:test")
+ self.get_success(
+ store.register_user(user_id=user2_caps.to_string(), password_hash=None)
+ )
+
+ # Attempting to login without matching a name exactly is an error.
+ userinfo = {
+ "sub": "test2",
+ "username": "TEST_USER_2",
+ }
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertTrue(
+ str(e.value).startswith(
+ "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
+ )
+ )
+
+ # Logging in when matching a name exactly should work.
+ user2 = UserID.from_string("@TEST_USER_2:test")
+ self.get_success(
+ store.register_user(user_id=user2.to_string(), password_hash=None)
+ )
+
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@TEST_USER_2:test")
+
+ def test_map_userinfo_to_invalid_localpart(self):
+ """If the mapping provider generates an invalid localpart it should be rejected."""
+ userinfo = {
+ "sub": "test2",
+ "username": "föö",
+ }
+ token = {}
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "localpart is invalid: föö")
|