diff --git a/changelog.d/4264.bugfix b/changelog.d/4264.bugfix
new file mode 100644
index 0000000000..b914026932
--- /dev/null
+++ b/changelog.d/4264.bugfix
@@ -0,0 +1 @@
+Fix CAS login when username is not valid in an MXID
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c6e89db4bc..2abd9af94f 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -563,10 +563,10 @@ class AuthHandler(BaseHandler):
insensitively, but return None if there are multiple inexact matches.
Args:
- (str) user_id: complete @user:id
+ (unicode|bytes) user_id: complete @user:id
Returns:
- defer.Deferred: (str) canonical_user_id, or None if zero or
+ defer.Deferred: (unicode) canonical_user_id, or None if zero or
multiple matches
"""
res = yield self._find_user_id_and_pwd_hash(user_id)
@@ -954,6 +954,15 @@ class MacaroonGenerator(object):
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
+ """
+
+ Args:
+ user_id (unicode):
+ duration_in_ms (int):
+
+ Returns:
+ unicode
+ """
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 011e84e8b1..b7c5b58b01 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -23,8 +23,12 @@ from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.server import finish_request
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.http.servlet import (
+ RestServlet,
+ parse_json_object_from_request,
+ parse_string,
+)
+from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
@@ -358,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
- self.auth_handler = hs.get_auth_handler()
- self.handlers = hs.get_handlers()
- self.macaroon_gen = hs.get_macaroon_generator()
+ self._sso_auth_handler = SSOAuthHandler(hs)
@defer.inlineCallbacks
def on_GET(self, request):
- client_redirect_url = request.args[b"redirectUrl"][0]
+ client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
- "ticket": request.args[b"ticket"][0].decode('ascii'),
+ "ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
}
try:
@@ -380,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)
- @defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
@@ -396,28 +397,9 @@ class CasTicketServlet(ClientV1RestServlet):
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
- user_id = UserID(user, self.hs.hostname).to_string()
- auth_handler = self.auth_handler
- registered_user_id = yield auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id, _ = (
- yield self.handlers.registration_handler.register(localpart=user)
- )
-
- login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ return self._sso_auth_handler.on_successful_auth(
+ user, request, client_redirect_url,
)
- redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
- login_token)
- request.redirect(redirect_url)
- finish_request(request)
-
- def add_login_token_to_redirect_url(self, url, token):
- url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({"loginToken": token})
- url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
- return urllib.parse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
user = None
@@ -452,6 +434,71 @@ class CasTicketServlet(ClientV1RestServlet):
return user, attributes
+class SSOAuthHandler(object):
+ """
+ Utility class for Resources and Servlets which handle the response from a SSO
+ service
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
+ def __init__(self, hs):
+ self._hostname = hs.hostname
+ self._auth_handler = hs.get_auth_handler()
+ self._registration_handler = hs.get_handlers().registration_handler
+ self._macaroon_gen = hs.get_macaroon_generator()
+
+ @defer.inlineCallbacks
+ def on_successful_auth(
+ self, username, request, client_redirect_url,
+ ):
+ """Called once the user has successfully authenticated with the SSO.
+
+ Registers the user if necessary, and then returns a redirect (with
+ a login token) to the client.
+
+ Args:
+ username (unicode|bytes): the remote user id. We'll map this onto
+ something sane for a MXID localpath.
+
+ request (SynapseRequest): the incoming request from the browser. We'll
+ respond to it with a redirect.
+
+ client_redirect_url (unicode): the redirect_url the client gave us when
+ it first started the process.
+
+ Returns:
+ Deferred[none]: Completes once we have handled the request.
+ """
+ localpart = map_username_to_mxid_localpart(username)
+ user_id = UserID(localpart, self._hostname).to_string()
+ registered_user_id = yield self._auth_handler.check_user_exists(user_id)
+ if not registered_user_id:
+ registered_user_id, _ = (
+ yield self._registration_handler.register(
+ localpart=localpart,
+ generate_token=False,
+ )
+ )
+
+ login_token = self._macaroon_gen.generate_short_term_login_token(
+ registered_user_id
+ )
+ redirect_url = self._add_login_token_to_redirect_url(
+ client_redirect_url, login_token
+ )
+ request.redirect(redirect_url)
+ finish_request(request)
+
+ @staticmethod
+ def _add_login_token_to_redirect_url(url, token):
+ url_parts = list(urllib.parse.urlparse(url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
+ query.update({"loginToken": token})
+ url_parts[4] = urllib.parse.urlencode(query)
+ return urllib.parse.urlunparse(url_parts)
+
+
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled:
diff --git a/synapse/types.py b/synapse/types.py
index 41afb27a74..d8cb64addb 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
import string
from collections import namedtuple
@@ -228,6 +229,71 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
+
+# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
+# localpart.
+#
+# It works by:
+# * building a string containing the allowed characters (excluding '=')
+# * escaping every special character with a backslash (to stop '-' being interpreted as a
+# range operator)
+# * wrapping it in a '[^...]' regex
+# * converting the whole lot to a 'bytes' sequence, so that we can use it to match
+# bytes rather than strings
+#
+NON_MXID_CHARACTER_PATTERN = re.compile(
+ ("[^%s]" % (
+ re.escape("".join(mxid_localpart_allowed_characters - {"="}),),
+ )).encode("ascii"),
+)
+
+
+def map_username_to_mxid_localpart(username, case_sensitive=False):
+ """Map a username onto a string suitable for a MXID
+
+ This follows the algorithm laid out at
+ https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
+
+ Args:
+ username (unicode|bytes): username to be mapped
+ case_sensitive (bool): true if TEST and test should be mapped
+ onto different mxids
+
+ Returns:
+ unicode: string suitable for a mxid localpart
+ """
+ if not isinstance(username, bytes):
+ username = username.encode('utf-8')
+
+ # first we sort out upper-case characters
+ if case_sensitive:
+ def f1(m):
+ return b"_" + m.group().lower()
+
+ username = UPPER_CASE_PATTERN.sub(f1, username)
+ else:
+ username = username.lower()
+
+ # then we sort out non-ascii characters
+ def f2(m):
+ g = m.group()[0]
+ if isinstance(g, str):
+ # on python 2, we need to do a ord(). On python 3, the
+ # byte itself will do.
+ g = ord(g)
+ return b"=%02x" % (g,)
+
+ username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
+
+ # we also do the =-escaping to mxids starting with an underscore.
+ username = re.sub(b'^_', b'=5f', username)
+
+ # we should now only have ascii bytes left, so can decode back to a
+ # unicode.
+ return username.decode('ascii')
+
+
class StreamToken(
namedtuple("Token", (
"room_key",
diff --git a/tests/test_types.py b/tests/test_types.py
index 0f5c8bfaf9..d314a7ff58 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID
+from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
from tests import unittest
from tests.utils import TestHomeServer
@@ -79,3 +79,32 @@ class GroupIDTestCase(unittest.TestCase):
except SynapseError as exc:
self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode)
+
+
+class MapUsernameTestCase(unittest.TestCase):
+ def testPassThrough(self):
+ self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
+
+ def testUpperCase(self):
+ self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
+ self.assertEqual(
+ map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
+ "t_e_s_t__1234",
+ )
+
+ def testSymbols(self):
+ self.assertEqual(
+ map_username_to_mxid_localpart("test=$?_1234"),
+ "test=3d=24=3f_1234",
+ )
+
+ def testLeadingUnderscore(self):
+ self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
+
+ def testNonAscii(self):
+ # this should work with either a unicode or a bytes
+ self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
+ self.assertEqual(
+ map_username_to_mxid_localpart(u'têst'.encode('utf-8')),
+ "t=c3=aast",
+ )
|