diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/handlers/auth.py | 54 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 55 |
2 files changed, 48 insertions, 61 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d6faa85f..eeca820845 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -30,6 +30,8 @@ import simplejson import synapse.util.stringutils as stringutils +import ldap + logger = logging.getLogger(__name__) @@ -49,6 +51,15 @@ class AuthHandler(BaseHandler): self.sessions = {} self.INVALID_TOKEN_HTTP_STATUS = 401 + self.ldap_enabled = hs.config.ldap_enabled + self.ldap_server = hs.config.ldap_server + self.ldap_port = hs.config.ldap_port + self.ldap_search_base = hs.config.ldap_search_base + self.ldap_search_property = hs.config.ldap_search_property + self.ldap_email_property = hs.config.ldap_email_property + self.ldap_full_name_property = hs.config.ldap_full_name_property + + @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ @@ -215,8 +226,8 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + self._check_password(user_id, password) + defer.returnValue(user_id) @defer.inlineCallbacks @@ -340,8 +351,8 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + + self._check_password(user_id, password) logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) @@ -407,12 +418,43 @@ class AuthHandler(BaseHandler): else: defer.returnValue(user_infos.popitem()) - def _check_password(self, user_id, password, stored_hash): + def _check_password(self, user_id, password): """Checks that user_id has passed password, raises LoginError if not.""" - if not self.validate_hash(password, stored_hash): + + if not (self._check_ldap_password(user_id, password) or self._check_local_password(user_id, password)): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) + def _check_local_password(self, user_id, password): + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + return not self.validate_hash(password, password_hash) + + def _check_ldap_password(self, user_id, password): + if not self.ldap_enabled: + return False + + logger.info("Authenticating %s with LDAP" % user_id) + try: + l = ldap.initialize("%s:%s" % (ldap_server, ldap_port)) + if self.ldap_tls: + logger.debug("Initiating TLS") + self._connection.start_tls_s() + + dn = "%s=%s, %s" % (ldap_search_property, user_id.localpart, ldap_search_base) + logger.debug("DN for LDAP authentication: %s" % dn) + + l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) + + if not self.does_user_exist(user_id): + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user_id.localpart) + ) + + return True + except ldap.LDAPError, e: + logger.info(e) + return False + @defer.inlineCallbacks def issue_access_token(self, user_id): access_token = self.generate_access_token(user_id) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 13720973be..da0fd2a8e0 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -36,8 +36,6 @@ import xml.etree.ElementTree as ET import jwt from jwt.exceptions import InvalidTokenError -import ldap - logger = logging.getLogger(__name__) @@ -49,7 +47,6 @@ class LoginRestServlet(ClientV1RestServlet): CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" JWT_TYPE = "m.login.jwt" - LDAP_TYPE = "m.login.ldap" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) @@ -59,13 +56,6 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm - self.ldap_enabled = hs.config.ldap_enabled - self.ldap_server = hs.config.ldap_server - self.ldap_port = hs.config.ldap_port - self.ldap_search_base = hs.config.ldap_search_base - self.ldap_search_property = hs.config.ldap_search_property - self.ldap_email_property = hs.config.ldap_email_property - self.ldap_full_name_property = hs.config.ldap_full_name_property self.cas_enabled = hs.config.cas_enabled self.cas_server_url = hs.config.cas_server_url self.cas_required_attributes = hs.config.cas_required_attributes @@ -74,8 +64,6 @@ class LoginRestServlet(ClientV1RestServlet): def on_GET(self, request): flows = [] - if self.ldap_enabled: - flows.append({"type": LoginRestServlet.LDAP_TYPE}) if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) if self.saml2_enabled: @@ -176,49 +164,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - @defer.inlineCallbacks - def do_ldap_login(self, login_submission): - if 'medium' in login_submission and 'address' in login_submission: - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - login_submission['medium'], login_submission['address'] - ) - if not user_id: - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - else: - user_id = login_submission['user'] - - if not user_id.startswith('@'): - user_id = UserID.create( - user_id, self.hs.hostname - ).to_string() - - # FIXME check against LDAP Server!! - - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user_id.localpart) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) - @defer.inlineCallbacks def do_token_login(self, login_submission): |