summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/admin.py11
-rw-r--r--synapse/rest/client/v1/login.py179
-rw-r--r--synapse/rest/media/v1/config_resource.py2
-rw-r--r--synapse/rest/saml2/__init__.py29
-rw-r--r--synapse/rest/saml2/metadata_resource.py36
-rw-r--r--synapse/rest/saml2/response_resource.py74
-rw-r--r--synapse/rest/well_known.py70
7 files changed, 303 insertions, 98 deletions
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 41534b8c2a..82433a2aa9 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -23,7 +23,7 @@ from six.moves import http_client
 
 from twisted.internet import defer
 
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, UserTypes
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import (
     assert_params_in_dict,
@@ -158,6 +158,11 @@ class UserRegisterServlet(ClientV1RestServlet):
                 raise SynapseError(400, "Invalid password")
 
         admin = body.get("admin", None)
+        user_type = body.get("user_type", None)
+
+        if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
+            raise SynapseError(400, "Invalid user type")
+
         got_mac = body["mac"]
 
         want_mac = hmac.new(
@@ -171,6 +176,9 @@ class UserRegisterServlet(ClientV1RestServlet):
         want_mac.update(password)
         want_mac.update(b"\x00")
         want_mac.update(b"admin" if admin else b"notadmin")
+        if user_type:
+            want_mac.update(b"\x00")
+            want_mac.update(user_type.encode('utf8'))
         want_mac = want_mac.hexdigest()
 
         if not hmac.compare_digest(
@@ -189,6 +197,7 @@ class UserRegisterServlet(ClientV1RestServlet):
             password=body["password"],
             admin=bool(admin),
             generate_token=False,
+            user_type=user_type,
         )
 
         result = yield register._create_registration_details(user_id, body)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f6b4a85e40..e9d3032498 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,17 +18,17 @@ import xml.etree.ElementTree as ET
 
 from six.moves import urllib
 
-from canonicaljson import json
-from saml2 import BINDING_HTTP_POST, config
-from saml2.client import Saml2Client
-
 from twisted.internet import defer
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.http.server import finish_request
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.http.servlet import (
+    RestServlet,
+    parse_json_object_from_request,
+    parse_string,
+)
+from synapse.types import UserID, map_username_to_mxid_localpart
 from synapse.util.msisdn import phone_number_to_msisdn
 
 from .base import ClientV1RestServlet, client_path_patterns
@@ -81,7 +81,6 @@ def login_id_thirdparty_from_phone(identifier):
 
 class LoginRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/login$")
-    SAML2_TYPE = "m.login.saml2"
     CAS_TYPE = "m.login.cas"
     SSO_TYPE = "m.login.sso"
     TOKEN_TYPE = "m.login.token"
@@ -89,8 +88,6 @@ class LoginRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(LoginRestServlet, self).__init__(hs)
-        self.idp_redirect_url = hs.config.saml2_idp_redirect_url
-        self.saml2_enabled = hs.config.saml2_enabled
         self.jwt_enabled = hs.config.jwt_enabled
         self.jwt_secret = hs.config.jwt_secret
         self.jwt_algorithm = hs.config.jwt_algorithm
@@ -103,8 +100,6 @@ class LoginRestServlet(ClientV1RestServlet):
         flows = []
         if self.jwt_enabled:
             flows.append({"type": LoginRestServlet.JWT_TYPE})
-        if self.saml2_enabled:
-            flows.append({"type": LoginRestServlet.SAML2_TYPE})
         if self.cas_enabled:
             flows.append({"type": LoginRestServlet.SSO_TYPE})
 
@@ -134,18 +129,8 @@ class LoginRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         login_submission = parse_json_object_from_request(request)
         try:
-            if self.saml2_enabled and (login_submission["type"] ==
-                                       LoginRestServlet.SAML2_TYPE):
-                relay_state = ""
-                if "relay_state" in login_submission:
-                    relay_state = "&RelayState=" + urllib.parse.quote(
-                                  login_submission["relay_state"])
-                result = {
-                    "uri": "%s%s" % (self.idp_redirect_url, relay_state)
-                }
-                defer.returnValue((200, result))
-            elif self.jwt_enabled and (login_submission["type"] ==
-                                       LoginRestServlet.JWT_TYPE):
+            if self.jwt_enabled and (login_submission["type"] ==
+                                     LoginRestServlet.JWT_TYPE):
                 result = yield self.do_jwt_login(login_submission)
                 defer.returnValue(result)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
@@ -345,50 +330,6 @@ class LoginRestServlet(ClientV1RestServlet):
         )
 
 
-class SAML2RestServlet(ClientV1RestServlet):
-    PATTERNS = client_path_patterns("/login/saml2", releases=())
-
-    def __init__(self, hs):
-        super(SAML2RestServlet, self).__init__(hs)
-        self.sp_config = hs.config.saml2_config_path
-        self.handlers = hs.get_handlers()
-
-    @defer.inlineCallbacks
-    def on_POST(self, request):
-        saml2_auth = None
-        try:
-            conf = config.SPConfig()
-            conf.load_file(self.sp_config)
-            SP = Saml2Client(conf)
-            saml2_auth = SP.parse_authn_request_response(
-                request.args['SAMLResponse'][0], BINDING_HTTP_POST)
-        except Exception as e:        # Not authenticated
-            logger.exception(e)
-        if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
-            username = saml2_auth.name_id.text
-            handler = self.handlers.registration_handler
-            (user_id, token) = yield handler.register_saml2(username)
-            # Forward to the RelayState callback along with ava
-            if 'RelayState' in request.args:
-                request.redirect(urllib.parse.unquote(
-                                 request.args['RelayState'][0]) +
-                                 '?status=authenticated&access_token=' +
-                                 token + '&user_id=' + user_id + '&ava=' +
-                                 urllib.quote(json.dumps(saml2_auth.ava)))
-                finish_request(request)
-                defer.returnValue(None)
-            defer.returnValue((200, {"status": "authenticated",
-                                     "user_id": user_id, "token": token,
-                                     "ava": saml2_auth.ava}))
-        elif 'RelayState' in request.args:
-            request.redirect(urllib.parse.unquote(
-                             request.args['RelayState'][0]) +
-                             '?status=not_authenticated')
-            finish_request(request)
-            defer.returnValue(None)
-        defer.returnValue((200, {"status": "not_authenticated"}))
-
-
 class CasRedirectServlet(RestServlet):
     PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
 
@@ -421,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet):
         self.cas_server_url = hs.config.cas_server_url
         self.cas_service_url = hs.config.cas_service_url
         self.cas_required_attributes = hs.config.cas_required_attributes
-        self.auth_handler = hs.get_auth_handler()
-        self.handlers = hs.get_handlers()
-        self.macaroon_gen = hs.get_macaroon_generator()
+        self._sso_auth_handler = SSOAuthHandler(hs)
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        client_redirect_url = request.args[b"redirectUrl"][0]
+        client_redirect_url = parse_string(request, "redirectUrl", required=True)
         http_client = self.hs.get_simple_http_client()
         uri = self.cas_server_url + "/proxyValidate"
         args = {
-            "ticket": request.args[b"ticket"][0].decode('ascii'),
+            "ticket": parse_string(request, "ticket", required=True),
             "service": self.cas_service_url
         }
         try:
@@ -443,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
         result = yield self.handle_cas_response(request, body, client_redirect_url)
         defer.returnValue(result)
 
-    @defer.inlineCallbacks
     def handle_cas_response(self, request, cas_response_body, client_redirect_url):
         user, attributes = self.parse_cas_response(cas_response_body)
 
@@ -459,28 +397,9 @@ class CasTicketServlet(ClientV1RestServlet):
                 if required_value != actual_value:
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
-        user_id = UserID(user, self.hs.hostname).to_string()
-        auth_handler = self.auth_handler
-        registered_user_id = yield auth_handler.check_user_exists(user_id)
-        if not registered_user_id:
-            registered_user_id, _ = (
-                yield self.handlers.registration_handler.register(localpart=user)
-            )
-
-        login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+        return self._sso_auth_handler.on_successful_auth(
+            user, request, client_redirect_url,
         )
-        redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
-                                                            login_token)
-        request.redirect(redirect_url)
-        finish_request(request)
-
-    def add_login_token_to_redirect_url(self, url, token):
-        url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({"loginToken": token})
-        url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
-        return urllib.parse.urlunparse(url_parts)
 
     def parse_cas_response(self, cas_response_body):
         user = None
@@ -515,10 +434,78 @@ class CasTicketServlet(ClientV1RestServlet):
         return user, attributes
 
 
+class SSOAuthHandler(object):
+    """
+    Utility class for Resources and Servlets which handle the response from a SSO
+    service
+
+    Args:
+        hs (synapse.server.HomeServer)
+    """
+    def __init__(self, hs):
+        self._hostname = hs.hostname
+        self._auth_handler = hs.get_auth_handler()
+        self._registration_handler = hs.get_handlers().registration_handler
+        self._macaroon_gen = hs.get_macaroon_generator()
+
+    @defer.inlineCallbacks
+    def on_successful_auth(
+        self, username, request, client_redirect_url,
+        user_display_name=None,
+    ):
+        """Called once the user has successfully authenticated with the SSO.
+
+        Registers the user if necessary, and then returns a redirect (with
+        a login token) to the client.
+
+        Args:
+            username (unicode|bytes): the remote user id. We'll map this onto
+                something sane for a MXID localpath.
+
+            request (SynapseRequest): the incoming request from the browser. We'll
+                respond to it with a redirect.
+
+            client_redirect_url (unicode): the redirect_url the client gave us when
+                it first started the process.
+
+            user_display_name (unicode|None): if set, and we have to register a new user,
+                we will set their displayname to this.
+
+        Returns:
+            Deferred[none]: Completes once we have handled the request.
+        """
+        localpart = map_username_to_mxid_localpart(username)
+        user_id = UserID(localpart, self._hostname).to_string()
+        registered_user_id = yield self._auth_handler.check_user_exists(user_id)
+        if not registered_user_id:
+            registered_user_id, _ = (
+                yield self._registration_handler.register(
+                    localpart=localpart,
+                    generate_token=False,
+                    default_display_name=user_display_name,
+                )
+            )
+
+        login_token = self._macaroon_gen.generate_short_term_login_token(
+            registered_user_id
+        )
+        redirect_url = self._add_login_token_to_redirect_url(
+            client_redirect_url, login_token
+        )
+        request.redirect(redirect_url)
+        finish_request(request)
+
+    @staticmethod
+    def _add_login_token_to_redirect_url(url, token):
+        url_parts = list(urllib.parse.urlparse(url))
+        query = dict(urllib.parse.parse_qsl(url_parts[4]))
+        query.update({"loginToken": token})
+        url_parts[4] = urllib.parse.urlencode(query)
+        return urllib.parse.urlunparse(url_parts)
+
+
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
-    if hs.config.saml2_enabled:
-        SAML2RestServlet(hs).register(http_server)
     if hs.config.cas_enabled:
         CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index d6605b6027..77316033f7 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -41,7 +41,7 @@ class MediaConfigResource(Resource):
     @defer.inlineCallbacks
     def _async_render_GET(self, request):
         yield self.auth.get_user_by_req(request)
-        respond_with_json(request, 200, self.limits_dict)
+        respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
     def render_OPTIONS(self, request):
         respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/saml2/__init__.py
new file mode 100644
index 0000000000..68da37ca6a
--- /dev/null
+++ b/synapse/rest/saml2/__init__.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.web.resource import Resource
+
+from synapse.rest.saml2.metadata_resource import SAML2MetadataResource
+from synapse.rest.saml2.response_resource import SAML2ResponseResource
+
+logger = logging.getLogger(__name__)
+
+
+class SAML2Resource(Resource):
+    def __init__(self, hs):
+        Resource.__init__(self)
+        self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
+        self.putChild(b"authn_response", SAML2ResponseResource(hs))
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/saml2/metadata_resource.py
new file mode 100644
index 0000000000..e8c680aeb4
--- /dev/null
+++ b/synapse/rest/saml2/metadata_resource.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import saml2.metadata
+
+from twisted.web.resource import Resource
+
+
+class SAML2MetadataResource(Resource):
+    """A Twisted web resource which renders the SAML metadata"""
+
+    isLeaf = 1
+
+    def __init__(self, hs):
+        Resource.__init__(self)
+        self.sp_config = hs.config.saml2_sp_config
+
+    def render_GET(self, request):
+        metadata_xml = saml2.metadata.create_metadata_string(
+            configfile=None, config=self.sp_config,
+        )
+        request.setHeader(b"Content-Type", b"text/xml; charset=utf-8")
+        return metadata_xml
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
new file mode 100644
index 0000000000..69fb77b322
--- /dev/null
+++ b/synapse/rest/saml2/response_resource.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import saml2
+from saml2.client import Saml2Client
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.api.errors import CodeMessageException
+from synapse.http.server import wrap_html_request_handler
+from synapse.http.servlet import parse_string
+from synapse.rest.client.v1.login import SSOAuthHandler
+
+logger = logging.getLogger(__name__)
+
+
+class SAML2ResponseResource(Resource):
+    """A Twisted web resource which handles the SAML response"""
+
+    isLeaf = 1
+
+    def __init__(self, hs):
+        Resource.__init__(self)
+
+        self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+        self._sso_auth_handler = SSOAuthHandler(hs)
+
+    def render_POST(self, request):
+        self._async_render_POST(request)
+        return NOT_DONE_YET
+
+    @wrap_html_request_handler
+    def _async_render_POST(self, request):
+        resp_bytes = parse_string(request, 'SAMLResponse', required=True)
+        relay_state = parse_string(request, 'RelayState', required=True)
+
+        try:
+            saml2_auth = self._saml_client.parse_authn_request_response(
+                resp_bytes, saml2.BINDING_HTTP_POST,
+            )
+        except Exception as e:
+            logger.warning("Exception parsing SAML2 response", exc_info=1)
+            raise CodeMessageException(
+                400, "Unable to parse SAML2 response: %s" % (e,),
+            )
+
+        if saml2_auth.not_signed:
+            raise CodeMessageException(400, "SAML2 response was not signed")
+
+        if "uid" not in saml2_auth.ava:
+            raise CodeMessageException(400, "uid not in SAML2 response")
+
+        username = saml2_auth.ava["uid"][0]
+
+        displayName = saml2_auth.ava.get("displayName", [None])[0]
+        return self._sso_auth_handler.on_successful_auth(
+            username, request, relay_state,
+            user_display_name=displayName,
+        )
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
new file mode 100644
index 0000000000..6e043d6162
--- /dev/null
+++ b/synapse/rest/well_known.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+
+from twisted.web.resource import Resource
+
+logger = logging.getLogger(__name__)
+
+
+class WellKnownBuilder(object):
+    """Utility to construct the well-known response
+
+    Args:
+        hs (synapse.server.HomeServer):
+    """
+    def __init__(self, hs):
+        self._config = hs.config
+
+    def get_well_known(self):
+        # if we don't have a public_base_url, we can't help much here.
+        if self._config.public_baseurl is None:
+            return None
+
+        result = {
+            "m.homeserver": {
+                "base_url": self._config.public_baseurl,
+            },
+        }
+
+        if self._config.default_identity_server:
+            result["m.identity_server"] = {
+                "base_url": self._config.default_identity_server,
+            }
+
+        return result
+
+
+class WellKnownResource(Resource):
+    """A Twisted web resource which renders the .well-known file"""
+
+    isLeaf = 1
+
+    def __init__(self, hs):
+        Resource.__init__(self)
+        self._well_known_builder = WellKnownBuilder(hs)
+
+    def render_GET(self, request):
+        r = self._well_known_builder.get_well_known()
+        if not r:
+            request.setResponseCode(404)
+            request.setHeader(b"Content-Type", b"text/plain")
+            return b'.well-known not available'
+
+        logger.error("returning: %s", r)
+        request.setHeader(b"Content-Type", b"application/json")
+        return json.dumps(r).encode("utf-8")