summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py121
1 files changed, 52 insertions, 69 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 3b60728628..4efb679a04 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission):
     to a typed object.
     """
     if "user" in submission:
-        submission["identifier"] = {
-            "type": "m.id.user",
-            "user": submission["user"],
-        }
+        submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
         del submission["user"]
 
     if "medium" in submission and "address" in submission:
@@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier):
 
     msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
 
-    return {
-        "type": "m.id.thirdparty",
-        "medium": "msisdn",
-        "address": msisdn,
-    }
+    return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
 
 
 class LoginRestServlet(RestServlet):
@@ -120,9 +113,9 @@ class LoginRestServlet(RestServlet):
             # login flow types returned.
             flows.append({"type": LoginRestServlet.TOKEN_TYPE})
 
-        flows.extend((
-            {"type": t} for t in self.auth_handler.get_supported_login_types()
-        ))
+        flows.extend(
+            ({"type": t} for t in self.auth_handler.get_supported_login_types())
+        )
 
         return (200, {"flows": flows})
 
@@ -132,7 +125,8 @@ class LoginRestServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request):
         self._address_ratelimiter.ratelimit(
-            request.getClientIP(), time_now_s=self.hs.clock.time(),
+            request.getClientIP(),
+            time_now_s=self.hs.clock.time(),
             rate_hz=self.hs.config.rc_login_address.per_second,
             burst_count=self.hs.config.rc_login_address.burst_count,
             update=True,
@@ -140,8 +134,9 @@ class LoginRestServlet(RestServlet):
 
         login_submission = parse_json_object_from_request(request)
         try:
-            if self.jwt_enabled and (login_submission["type"] ==
-                                     LoginRestServlet.JWT_TYPE):
+            if self.jwt_enabled and (
+                login_submission["type"] == LoginRestServlet.JWT_TYPE
+            ):
                 result = yield self.do_jwt_login(login_submission)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                 result = yield self.do_token_login(login_submission)
@@ -170,10 +165,10 @@ class LoginRestServlet(RestServlet):
         # field)
         logger.info(
             "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
-            login_submission.get('identifier'),
-            login_submission.get('medium'),
-            login_submission.get('address'),
-            login_submission.get('user'),
+            login_submission.get("identifier"),
+            login_submission.get("medium"),
+            login_submission.get("address"),
+            login_submission.get("user"),
         )
         login_submission_legacy_convert(login_submission)
 
@@ -190,13 +185,13 @@ class LoginRestServlet(RestServlet):
 
         # convert threepid identifiers to user IDs
         if identifier["type"] == "m.id.thirdparty":
-            address = identifier.get('address')
-            medium = identifier.get('medium')
+            address = identifier.get("address")
+            medium = identifier.get("medium")
 
             if medium is None or address is None:
                 raise SynapseError(400, "Invalid thirdparty identifier")
 
-            if medium == 'email':
+            if medium == "email":
                 # For emails, transform the address to lowercase.
                 # We store all email addreses as lowercase in the DB.
                 # (See add_threepid in synapse/handlers/auth.py)
@@ -205,34 +200,28 @@ class LoginRestServlet(RestServlet):
             # Check for login providers that support 3pid login types
             canonical_user_id, callback_3pid = (
                 yield self.auth_handler.check_password_provider_3pid(
-                    medium,
-                    address,
-                    login_submission["password"],
+                    medium, address, login_submission["password"]
                 )
             )
             if canonical_user_id:
                 # Authentication through password provider and 3pid succeeded
                 result = yield self._register_device_with_callback(
-                    canonical_user_id, login_submission, callback_3pid,
+                    canonical_user_id, login_submission, callback_3pid
                 )
                 defer.returnValue(result)
 
             # No password providers were able to handle this 3pid
             # Check local store
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                medium, address,
+                medium, address
             )
             if not user_id:
                 logger.warn(
-                    "unknown 3pid identifier medium %s, address %r",
-                    medium, address,
+                    "unknown 3pid identifier medium %s, address %r", medium, address
                 )
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
 
-            identifier = {
-                "type": "m.id.user",
-                "user": user_id,
-            }
+            identifier = {"type": "m.id.user", "user": user_id}
 
         # by this point, the identifier should be an m.id.user: if it's anything
         # else, we haven't understood it.
@@ -242,22 +231,16 @@ class LoginRestServlet(RestServlet):
             raise SynapseError(400, "User identifier is missing 'user' key")
 
         canonical_user_id, callback = yield self.auth_handler.validate_login(
-            identifier["user"],
-            login_submission,
+            identifier["user"], login_submission
         )
 
         result = yield self._register_device_with_callback(
-            canonical_user_id, login_submission, callback,
+            canonical_user_id, login_submission, callback
         )
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def _register_device_with_callback(
-        self,
-        user_id,
-        login_submission,
-        callback=None,
-    ):
+    def _register_device_with_callback(self, user_id, login_submission, callback=None):
         """ Registers a device with a given user_id. Optionally run a callback
         function after registration has completed.
 
@@ -273,7 +256,7 @@ class LoginRestServlet(RestServlet):
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
         device_id, access_token = yield self.registration_handler.register_device(
-            user_id, device_id, initial_display_name,
+            user_id, device_id, initial_display_name
         )
 
         result = {
@@ -290,7 +273,7 @@ class LoginRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def do_token_login(self, login_submission):
-        token = login_submission['token']
+        token = login_submission["token"]
         auth_handler = self.auth_handler
         user_id = (
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
@@ -299,7 +282,7 @@ class LoginRestServlet(RestServlet):
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
         device_id, access_token = yield self.registration_handler.register_device(
-            user_id, device_id, initial_display_name,
+            user_id, device_id, initial_display_name
         )
 
         result = {
@@ -316,15 +299,16 @@ class LoginRestServlet(RestServlet):
         token = login_submission.get("token", None)
         if token is None:
             raise LoginError(
-                401, "Token field for JWT is missing",
-                errcode=Codes.UNAUTHORIZED
+                401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
             )
 
         import jwt
         from jwt.exceptions import InvalidTokenError
 
         try:
-            payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
+            payload = jwt.decode(
+                token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+            )
         except jwt.ExpiredSignatureError:
             raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
         except InvalidTokenError:
@@ -342,7 +326,7 @@ class LoginRestServlet(RestServlet):
             device_id = login_submission.get("device_id")
             initial_display_name = login_submission.get("initial_device_display_name")
             device_id, access_token = yield self.registration_handler.register_device(
-                registered_user_id, device_id, initial_display_name,
+                registered_user_id, device_id, initial_display_name
             )
 
             result = {
@@ -358,7 +342,7 @@ class LoginRestServlet(RestServlet):
             device_id = login_submission.get("device_id")
             initial_display_name = login_submission.get("initial_device_display_name")
             device_id, access_token = yield self.registration_handler.register_device(
-                registered_user_id, device_id, initial_display_name,
+                registered_user_id, device_id, initial_display_name
             )
 
             result = {
@@ -375,21 +359,20 @@ class CasRedirectServlet(RestServlet):
 
     def __init__(self, hs):
         super(CasRedirectServlet, self).__init__()
-        self.cas_server_url = hs.config.cas_server_url.encode('ascii')
-        self.cas_service_url = hs.config.cas_service_url.encode('ascii')
+        self.cas_server_url = hs.config.cas_server_url.encode("ascii")
+        self.cas_service_url = hs.config.cas_service_url.encode("ascii")
 
     def on_GET(self, request):
         args = request.args
         if b"redirectUrl" not in args:
             return (400, "Redirect URL not specified for CAS auth")
-        client_redirect_url_param = urllib.parse.urlencode({
-            b"redirectUrl": args[b"redirectUrl"][0]
-        }).encode('ascii')
-        hs_redirect_url = (self.cas_service_url +
-                           b"/_matrix/client/r0/login/cas/ticket")
-        service_param = urllib.parse.urlencode({
-            b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
-        }).encode('ascii')
+        client_redirect_url_param = urllib.parse.urlencode(
+            {b"redirectUrl": args[b"redirectUrl"][0]}
+        ).encode("ascii")
+        hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
+        service_param = urllib.parse.urlencode(
+            {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
+        ).encode("ascii")
         request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
         finish_request(request)
 
@@ -411,7 +394,7 @@ class CasTicketServlet(RestServlet):
         uri = self.cas_server_url + "/proxyValidate"
         args = {
             "ticket": parse_string(request, "ticket", required=True),
-            "service": self.cas_service_url
+            "service": self.cas_service_url,
         }
         try:
             body = yield self._http_client.get_raw(uri, args)
@@ -438,7 +421,7 @@ class CasTicketServlet(RestServlet):
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
         return self._sso_auth_handler.on_successful_auth(
-            user, request, client_redirect_url,
+            user, request, client_redirect_url
         )
 
     def parse_cas_response(self, cas_response_body):
@@ -448,7 +431,7 @@ class CasTicketServlet(RestServlet):
             root = ET.fromstring(cas_response_body)
             if not root.tag.endswith("serviceResponse"):
                 raise Exception("root of CAS response is not serviceResponse")
-            success = (root[0].tag.endswith("authenticationSuccess"))
+            success = root[0].tag.endswith("authenticationSuccess")
             for child in root[0]:
                 if child.tag.endswith("user"):
                     user = child.text
@@ -466,11 +449,11 @@ class CasTicketServlet(RestServlet):
                 raise Exception("CAS response does not contain user")
         except Exception:
             logger.error("Error parsing CAS response", exc_info=1)
-            raise LoginError(401, "Invalid CAS response",
-                             errcode=Codes.UNAUTHORIZED)
+            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
         if not success:
-            raise LoginError(401, "Unsuccessful CAS response",
-                             errcode=Codes.UNAUTHORIZED)
+            raise LoginError(
+                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+            )
         return user, attributes
 
 
@@ -482,6 +465,7 @@ class SSOAuthHandler(object):
     Args:
         hs (synapse.server.HomeServer)
     """
+
     def __init__(self, hs):
         self._hostname = hs.hostname
         self._auth_handler = hs.get_auth_handler()
@@ -490,8 +474,7 @@ class SSOAuthHandler(object):
 
     @defer.inlineCallbacks
     def on_successful_auth(
-        self, username, request, client_redirect_url,
-        user_display_name=None,
+        self, username, request, client_redirect_url, user_display_name=None
     ):
         """Called once the user has successfully authenticated with the SSO.