summary refs log tree commit diff
path: root/synapse/rest/client/v1
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1')
-rw-r--r--synapse/rest/client/v1/login.py96
-rw-r--r--synapse/rest/client/v1/room.py2
2 files changed, 90 insertions, 8 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index b2257b749d..2444f27366 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -20,14 +20,32 @@ from synapse.types import UserID
 from base import ClientV1RestServlet, client_path_pattern
 
 import simplejson as json
+import urllib
+
+import logging
+from saml2 import BINDING_HTTP_POST
+from saml2 import config
+from saml2.client import Saml2Client
+
+
+logger = logging.getLogger(__name__)
 
 
 class LoginRestServlet(ClientV1RestServlet):
     PATTERN = client_path_pattern("/login$")
     PASS_TYPE = "m.login.password"
+    SAML2_TYPE = "m.login.saml2"
+
+    def __init__(self, hs):
+        super(LoginRestServlet, self).__init__(hs)
+        self.idp_redirect_url = hs.config.saml2_idp_redirect_url
+        self.saml2_enabled = hs.config.saml2_enabled
 
     def on_GET(self, request):
-        return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
+        flows = [{"type": LoginRestServlet.PASS_TYPE}]
+        if self.saml2_enabled:
+            flows.append({"type": LoginRestServlet.SAML2_TYPE})
+        return (200, {"flows": flows})
 
     def on_OPTIONS(self, request):
         return (200, {})
@@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
             if login_submission["type"] == LoginRestServlet.PASS_TYPE:
                 result = yield self.do_password_login(login_submission)
                 defer.returnValue(result)
+            elif self.saml2_enabled and (login_submission["type"] ==
+                                         LoginRestServlet.SAML2_TYPE):
+                relay_state = ""
+                if "relay_state" in login_submission:
+                    relay_state = "&RelayState="+urllib.quote(
+                                  login_submission["relay_state"])
+                result = {
+                    "uri": "%s%s" % (self.idp_redirect_url, relay_state)
+                }
+                defer.returnValue((200, result))
             else:
                 raise SynapseError(400, "Bad login type.")
         except KeyError:
@@ -46,17 +74,24 @@ class LoginRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def do_password_login(self, login_submission):
-        if not login_submission["user"].startswith('@'):
-            login_submission["user"] = UserID.create(
-                login_submission["user"], self.hs.hostname).to_string()
+        if 'medium' in login_submission and 'address' in login_submission:
+            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+                login_submission['medium'], login_submission['address']
+            )
+        else:
+            user_id = login_submission['user']
+
+        if not user_id.startswith('@'):
+            user_id = UserID.create(
+                user_id, self.hs.hostname
+            ).to_string()
 
-        handler = self.handlers.login_handler
-        token = yield handler.login(
-            user=login_submission["user"],
+        user_id, token = yield self.handlers.auth_handler.login_with_password(
+            user_id=user_id,
             password=login_submission["password"])
 
         result = {
-            "user_id": login_submission["user"],  # may have changed
+            "user_id": user_id,  # may have changed
             "access_token": token,
             "home_server": self.hs.hostname,
         }
@@ -94,6 +129,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
             )
 
 
+class SAML2RestServlet(ClientV1RestServlet):
+    PATTERN = client_path_pattern("/login/saml2")
+
+    def __init__(self, hs):
+        super(SAML2RestServlet, self).__init__(hs)
+        self.sp_config = hs.config.saml2_config_path
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        saml2_auth = None
+        try:
+            conf = config.SPConfig()
+            conf.load_file(self.sp_config)
+            SP = Saml2Client(conf)
+            saml2_auth = SP.parse_authn_request_response(
+                request.args['SAMLResponse'][0], BINDING_HTTP_POST)
+        except Exception, e:        # Not authenticated
+            logger.exception(e)
+        if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
+            username = saml2_auth.name_id.text
+            handler = self.handlers.registration_handler
+            (user_id, token) = yield handler.register_saml2(username)
+            # Forward to the RelayState callback along with ava
+            if 'RelayState' in request.args:
+                request.redirect(urllib.unquote(
+                                 request.args['RelayState'][0]) +
+                                 '?status=authenticated&access_token=' +
+                                 token + '&user_id=' + user_id + '&ava=' +
+                                 urllib.quote(json.dumps(saml2_auth.ava)))
+                request.finish()
+                defer.returnValue(None)
+            defer.returnValue((200, {"status": "authenticated",
+                                     "user_id": user_id, "token": token,
+                                     "ava": saml2_auth.ava}))
+        elif 'RelayState' in request.args:
+            request.redirect(urllib.unquote(
+                             request.args['RelayState'][0]) +
+                             '?status=not_authenticated')
+            request.finish()
+            defer.returnValue(None)
+        defer.returnValue((200, {"status": "not_authenticated"}))
+
+
 def _parse_json(request):
     try:
         content = json.loads(request.content.read())
@@ -106,4 +184,6 @@ def _parse_json(request):
 
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    if hs.config.saml2_enabled:
+        SAML2RestServlet(hs).register(http_server)
     # TODO PasswordResetRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 0346afb1b4..b4a70cba99 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -412,6 +412,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
             if "user_id" not in content:
                 raise SynapseError(400, "Missing user_id key.")
             state_key = content["user_id"]
+            # make sure it looks like a user ID; it'll throw if it's invalid.
+            UserID.from_string(state_key)
 
             if membership_action == "kick":
                 membership_action = "leave"