summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py209
1 files changed, 174 insertions, 35 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index b38f81e999..e259213a36 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -20,6 +20,7 @@ from synapse.api.constants import LoginType
 from synapse.types import UserID
 from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
 from synapse.util.async import run_on_reactor
+from synapse.config.ldap import LDAPMode
 
 from twisted.web.client import PartialDownloadError
 
@@ -28,6 +29,12 @@ import bcrypt
 import pymacaroons
 import simplejson
 
+try:
+    import ldap3
+except ImportError:
+    ldap3 = None
+    pass
+
 import synapse.util.stringutils as stringutils
 
 
@@ -50,17 +57,20 @@ class AuthHandler(BaseHandler):
         self.INVALID_TOKEN_HTTP_STATUS = 401
 
         self.ldap_enabled = hs.config.ldap_enabled
-        self.ldap_server = hs.config.ldap_server
-        self.ldap_port = hs.config.ldap_port
-        self.ldap_tls = hs.config.ldap_tls
-        self.ldap_search_base = hs.config.ldap_search_base
-        self.ldap_search_property = hs.config.ldap_search_property
-        self.ldap_email_property = hs.config.ldap_email_property
-        self.ldap_full_name_property = hs.config.ldap_full_name_property
-
-        if self.ldap_enabled is True:
-            import ldap
-            logger.info("Import ldap version: %s", ldap.__version__)
+        if self.ldap_enabled:
+            if not ldap3:
+                raise RuntimeError(
+                    'Missing ldap3 library. This is required for LDAP Authentication.'
+                )
+            self.ldap_mode = hs.config.ldap_mode
+            self.ldap_uri = hs.config.ldap_uri
+            self.ldap_start_tls = hs.config.ldap_start_tls
+            self.ldap_base = hs.config.ldap_base
+            self.ldap_filter = hs.config.ldap_filter
+            self.ldap_attributes = hs.config.ldap_attributes
+            if self.ldap_mode == LDAPMode.SEARCH:
+                self.ldap_bind_dn = hs.config.ldap_bind_dn
+                self.ldap_bind_password = hs.config.ldap_bind_password
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
 
@@ -452,40 +462,167 @@ class AuthHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _check_ldap_password(self, user_id, password):
-        if not self.ldap_enabled:
-            logger.debug("LDAP not configured")
+        """ Attempt to authenticate a user against an LDAP Server
+            and register an account if none exists.
+
+            Returns:
+                True if authentication against LDAP was successful
+        """
+
+        if not ldap3 or not self.ldap_enabled:
             defer.returnValue(False)
 
-        import ldap
+        if self.ldap_mode not in LDAPMode.LIST:
+            raise RuntimeError(
+                'Invalid ldap mode specified: {mode}'.format(
+                    mode=self.ldap_mode
+                )
+            )
 
-        logger.info("Authenticating %s with LDAP" % user_id)
         try:
-            ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
-            logger.debug("Connecting LDAP server at %s" % ldap_url)
-            l = ldap.initialize(ldap_url)
-            if self.ldap_tls:
-                logger.debug("Initiating TLS")
-                self._connection.start_tls_s()
+            server = ldap3.Server(self.ldap_uri)
+            logger.debug(
+                "Attempting ldap connection with %s",
+                self.ldap_uri
+            )
 
-            local_name = UserID.from_string(user_id).localpart
+            localpart = UserID.from_string(user_id).localpart
+            if self.ldap_mode == LDAPMode.SIMPLE:
+                # bind with the the local users ldap credentials
+                bind_dn = "{prop}={value},{base}".format(
+                    prop=self.ldap_attributes['uid'],
+                    value=localpart,
+                    base=self.ldap_base
+                )
+                conn = ldap3.Connection(server, bind_dn, password)
+                logger.debug(
+                    "Established ldap connection in simple mode: %s",
+                    conn
+                )
 
-            dn = "%s=%s, %s" % (
-                self.ldap_search_property,
-                local_name,
-                self.ldap_search_base)
-            logger.debug("DN for LDAP authentication: %s" % dn)
+                if self.ldap_start_tls:
+                    conn.start_tls()
+                    logger.debug(
+                        "Upgraded ldap connection in simple mode through StartTLS: %s",
+                        conn
+                    )
+
+                conn.bind()
+
+            elif self.ldap_mode == LDAPMode.SEARCH:
+                # connect with preconfigured credentials and search for local user
+                conn = ldap3.Connection(
+                    server,
+                    self.ldap_bind_dn,
+                    self.ldap_bind_password
+                )
+                logger.debug(
+                    "Established ldap connection in search mode: %s",
+                    conn
+                )
+
+                if self.ldap_start_tls:
+                    conn.start_tls()
+                    logger.debug(
+                        "Upgraded ldap connection in search mode through StartTLS: %s",
+                        conn
+                    )
 
-            l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
+                conn.bind()
 
+                # find matching dn
+                query = "({prop}={value})".format(
+                    prop=self.ldap_attributes['uid'],
+                    value=localpart
+                )
+                if self.ldap_filter:
+                    query = "(&{query}{filter})".format(
+                        query=query,
+                        filter=self.ldap_filter
+                    )
+                logger.debug("ldap search filter: %s", query)
+                result = conn.search(self.ldap_base, query)
+
+                if result and len(conn.response) == 1:
+                    # found exactly one result
+                    user_dn = conn.response[0]['dn']
+                    logger.debug('ldap search found dn: %s', user_dn)
+
+                    # unbind and reconnect, rebind with found dn
+                    conn.unbind()
+                    conn = ldap3.Connection(
+                        server,
+                        user_dn,
+                        password,
+                        auto_bind=True
+                    )
+                else:
+                    # found 0 or > 1 results, abort!
+                    logger.warn(
+                        "ldap search returned unexpected (%d!=1) amount of results",
+                        len(conn.response)
+                    )
+                    defer.returnValue(False)
+
+            logger.info(
+                "User authenticated against ldap server: %s",
+                conn
+            )
+
+            # check for existing account, if none exists, create one
             if not (yield self.does_user_exist(user_id)):
-                handler = self.hs.get_handlers().registration_handler
-                user_id, access_token = (
-                    yield handler.register(localpart=local_name)
+                # query user metadata for account creation
+                query = "({prop}={value})".format(
+                    prop=self.ldap_attributes['uid'],
+                    value=localpart
+                )
+
+                if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
+                    query = "(&{filter}{user_filter})".format(
+                        filter=query,
+                        user_filter=self.ldap_filter
+                    )
+                logger.debug("ldap registration filter: %s", query)
+
+                result = conn.search(
+                    search_base=self.ldap_base,
+                    search_filter=query,
+                    attributes=[
+                        self.ldap_attributes['name'],
+                        self.ldap_attributes['mail']
+                    ]
                 )
 
+                if len(conn.response) == 1:
+                    attrs = conn.response[0]['attributes']
+                    mail = attrs[self.ldap_attributes['mail']][0]
+                    name = attrs[self.ldap_attributes['name']][0]
+
+                    # create account
+                    registration_handler = self.hs.get_handlers().registration_handler
+                    user_id, access_token = (
+                        yield registration_handler.register(localpart=localpart)
+                    )
+
+                    # TODO: bind email, set displayname with data from ldap directory
+
+                    logger.info(
+                        "ldap registration successful: %d: %s (%s, %)",
+                        user_id,
+                        localpart,
+                        name,
+                        mail
+                    )
+                else:
+                    logger.warn(
+                        "ldap registration failed: unexpected (%d!=1) amount of results",
+                        len(result)
+                    )
+                    defer.returnValue(False)
+
             defer.returnValue(True)
-        except ldap.LDAPError, e:
-            logger.warn("LDAP error: %s", e)
+        except ldap3.core.exceptions.LDAPException as e:
+            logger.warn("Error during ldap authentication: %s", e)
             defer.returnValue(False)
 
     @defer.inlineCallbacks
@@ -613,7 +750,8 @@ class AuthHandler(BaseHandler):
         Returns:
             Hashed password (str).
         """
-        return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
+        return bcrypt.hashpw(password + self.hs.config.password_pepper,
+                             bcrypt.gensalt(self.bcrypt_rounds))
 
     def validate_hash(self, password, stored_hash):
         """Validates that self.hash(password) == stored_hash.
@@ -626,6 +764,7 @@ class AuthHandler(BaseHandler):
             Whether self.hash(password) == stored_hash (bool).
         """
         if stored_hash:
-            return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash
+            return bcrypt.hashpw(password + self.hs.config.password_pepper,
+                                 stored_hash.encode('utf-8')) == stored_hash
         else:
             return False