diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 1a886cbbbf..a31d277935 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission):
to a typed object.
"""
if "user" in submission:
- submission["identifier"] = {
- "type": "m.id.user",
- "user": submission["user"],
- }
+ submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
del submission["user"]
if "medium" in submission and "address" in submission:
@@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier):
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
- return {
- "type": "m.id.thirdparty",
- "medium": "msisdn",
- "address": msisdn,
- }
+ return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
class LoginRestServlet(RestServlet):
@@ -124,9 +117,9 @@ class LoginRestServlet(RestServlet):
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- flows.extend((
- {"type": t} for t in self.auth_handler.get_supported_login_types()
- ))
+ flows.extend(
+ ({"type": t} for t in self.auth_handler.get_supported_login_types())
+ )
return (200, {"flows": flows})
@@ -136,7 +129,8 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
self._address_ratelimiter.ratelimit(
- request.getClientIP(), time_now_s=self.hs.clock.time(),
+ request.getClientIP(),
+ time_now_s=self.hs.clock.time(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
update=True,
@@ -144,8 +138,9 @@ class LoginRestServlet(RestServlet):
login_submission = parse_json_object_from_request(request)
try:
- if self.jwt_enabled and (login_submission["type"] ==
- LoginRestServlet.JWT_TYPE):
+ if self.jwt_enabled and (
+ login_submission["type"] == LoginRestServlet.JWT_TYPE
+ ):
result = yield self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
@@ -174,10 +169,10 @@ class LoginRestServlet(RestServlet):
# field)
logger.info(
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
- login_submission.get('identifier'),
- login_submission.get('medium'),
- login_submission.get('address'),
- login_submission.get('user'),
+ login_submission.get("identifier"),
+ login_submission.get("medium"),
+ login_submission.get("address"),
+ login_submission.get("user"),
)
login_submission_legacy_convert(login_submission)
@@ -194,13 +189,13 @@ class LoginRestServlet(RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
- address = identifier.get('address')
- medium = identifier.get('medium')
+ address = identifier.get("address")
+ medium = identifier.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
- if medium == 'email':
+ if medium == "email":
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
@@ -209,34 +204,28 @@ class LoginRestServlet(RestServlet):
# Check for login providers that support 3pid login types
canonical_user_id, callback_3pid = (
yield self.auth_handler.check_password_provider_3pid(
- medium,
- address,
- login_submission["password"],
+ medium, address, login_submission["password"]
)
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
result = yield self._register_device_with_callback(
- canonical_user_id, login_submission, callback_3pid,
+ canonical_user_id, login_submission, callback_3pid
)
defer.returnValue(result)
# No password providers were able to handle this 3pid
# Check local store
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- medium, address,
+ medium, address
)
if not user_id:
logger.warn(
- "unknown 3pid identifier medium %s, address %r",
- medium, address,
+ "unknown 3pid identifier medium %s, address %r", medium, address
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
- identifier = {
- "type": "m.id.user",
- "user": user_id,
- }
+ identifier = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
@@ -246,22 +235,16 @@ class LoginRestServlet(RestServlet):
raise SynapseError(400, "User identifier is missing 'user' key")
canonical_user_id, callback = yield self.auth_handler.validate_login(
- identifier["user"],
- login_submission,
+ identifier["user"], login_submission
)
result = yield self._register_device_with_callback(
- canonical_user_id, login_submission, callback,
+ canonical_user_id, login_submission, callback
)
defer.returnValue(result)
@defer.inlineCallbacks
- def _register_device_with_callback(
- self,
- user_id,
- login_submission,
- callback=None,
- ):
+ def _register_device_with_callback(self, user_id, login_submission, callback=None):
""" Registers a device with a given user_id. Optionally run a callback
function after registration has completed.
@@ -277,7 +260,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name,
+ user_id, device_id, initial_display_name
)
result = {
@@ -294,7 +277,7 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks
def do_token_login(self, login_submission):
- token = login_submission['token']
+ token = login_submission["token"]
auth_handler = self.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
@@ -303,7 +286,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name,
+ user_id, device_id, initial_display_name
)
result = {
@@ -320,15 +303,16 @@ class LoginRestServlet(RestServlet):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
- 401, "Token field for JWT is missing",
- errcode=Codes.UNAUTHORIZED
+ 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
)
import jwt
from jwt.exceptions import InvalidTokenError
try:
- payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
+ payload = jwt.decode(
+ token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+ )
except jwt.ExpiredSignatureError:
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
except InvalidTokenError:
@@ -346,7 +330,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- registered_user_id, device_id, initial_display_name,
+ registered_user_id, device_id, initial_display_name
)
result = {
@@ -362,7 +346,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- registered_user_id, device_id, initial_display_name,
+ registered_user_id, device_id, initial_display_name
)
result = {
@@ -376,6 +360,7 @@ class LoginRestServlet(RestServlet):
class BaseSsoRedirectServlet(RestServlet):
"""Common base class for /login/sso/redirect impls"""
+
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def on_GET(self, request):
@@ -401,21 +386,20 @@ class BaseSsoRedirectServlet(RestServlet):
raise NotImplementedError()
-class CasRedirectServlet(RestServlet):
+class CasRedirectServlet(BaseSsoRedirectServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url.encode('ascii')
- self.cas_service_url = hs.config.cas_service_url.encode('ascii')
+ self.cas_server_url = hs.config.cas_server_url.encode("ascii")
+ self.cas_service_url = hs.config.cas_service_url.encode("ascii")
def get_sso_url(self, client_redirect_url):
- client_redirect_url_param = urllib.parse.urlencode({
- b"redirectUrl": client_redirect_url
- }).encode('ascii')
- hs_redirect_url = (self.cas_service_url +
- b"/_matrix/client/r0/login/cas/ticket")
- service_param = urllib.parse.urlencode({
- b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
- }).encode('ascii')
+ client_redirect_url_param = urllib.parse.urlencode(
+ {b"redirectUrl": client_redirect_url}
+ ).encode("ascii")
+ hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
+ service_param = urllib.parse.urlencode(
+ {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
+ ).encode("ascii")
return b"%s/login?%s" % (self.cas_server_url, service_param)
@@ -436,7 +420,7 @@ class CasTicketServlet(RestServlet):
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
- "service": self.cas_service_url
+ "service": self.cas_service_url,
}
try:
body = yield self._http_client.get_raw(uri, args)
@@ -463,7 +447,7 @@ class CasTicketServlet(RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth(
- user, request, client_redirect_url,
+ user, request, client_redirect_url
)
def parse_cas_response(self, cas_response_body):
@@ -473,7 +457,7 @@ class CasTicketServlet(RestServlet):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
- success = (root[0].tag.endswith("authenticationSuccess"))
+ success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@@ -491,11 +475,11 @@ class CasTicketServlet(RestServlet):
raise Exception("CAS response does not contain user")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
- raise LoginError(401, "Invalid CAS response",
- errcode=Codes.UNAUTHORIZED)
+ raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
- raise LoginError(401, "Unsuccessful CAS response",
- errcode=Codes.UNAUTHORIZED)
+ raise LoginError(
+ 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+ )
return user, attributes
@@ -507,11 +491,11 @@ class SAMLRedirectServlet(BaseSsoRedirectServlet):
def get_sso_url(self, client_redirect_url):
reqid, info = self._saml_client.prepare_for_authenticate(
- relay_state=client_redirect_url,
+ relay_state=client_redirect_url
)
- for key, value in info['headers']:
- if key == 'Location':
+ for key, value in info["headers"]:
+ if key == "Location":
return value
# this shouldn't happen!
@@ -526,6 +510,7 @@ class SSOAuthHandler(object):
Args:
hs (synapse.server.HomeServer)
"""
+
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
@@ -534,8 +519,7 @@ class SSOAuthHandler(object):
@defer.inlineCallbacks
def on_successful_auth(
- self, username, request, client_redirect_url,
- user_display_name=None,
+ self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.
|