summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/auth.py142
1 files changed, 119 insertions, 23 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 82d458b424..200793b5ed 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 from ._base import BaseHandler
 from synapse.api.constants import LoginType
 from synapse.types import UserID
-from synapse.api.errors import AuthError, LoginError, Codes
+from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
 from synapse.util.async import run_on_reactor
 
 from twisted.web.client import PartialDownloadError
@@ -49,6 +49,21 @@ class AuthHandler(BaseHandler):
         self.sessions = {}
         self.INVALID_TOKEN_HTTP_STATUS = 401
 
+        self.ldap_enabled = hs.config.ldap_enabled
+        self.ldap_server = hs.config.ldap_server
+        self.ldap_port = hs.config.ldap_port
+        self.ldap_tls = hs.config.ldap_tls
+        self.ldap_search_base = hs.config.ldap_search_base
+        self.ldap_search_property = hs.config.ldap_search_property
+        self.ldap_email_property = hs.config.ldap_email_property
+        self.ldap_full_name_property = hs.config.ldap_full_name_property
+
+        if self.ldap_enabled is True:
+            import ldap
+            logger.info("Import ldap version: %s", ldap.__version__)
+
+        self.hs = hs  # FIXME better possibility to access registrationHandler later?
+
     @defer.inlineCallbacks
     def check_auth(self, flows, clientdict, clientip):
         """
@@ -163,9 +178,13 @@ class AuthHandler(BaseHandler):
     def get_session_id(self, clientdict):
         """
         Gets the session ID for a client given the client dictionary
-        :param clientdict: The dictionary sent by the client in the request
-        :return: The string session ID the client sent. If the client did not
-                 send a session ID, returns None.
+
+        Args:
+            clientdict: The dictionary sent by the client in the request
+
+        Returns:
+            str|None: The string session ID the client sent. If the client did
+                not send a session ID, returns None.
         """
         sid = None
         if clientdict and 'auth' in clientdict:
@@ -179,9 +198,11 @@ class AuthHandler(BaseHandler):
         Store a key-value pair into the sessions data associated with this
         request. This data is stored server-side and cannot be modified by
         the client.
-        :param session_id: (string) The ID of this session as returned from check_auth
-        :param key: (string) The key to store the data under
-        :param value: (any) The data to store
+
+        Args:
+            session_id (string): The ID of this session as returned from check_auth
+            key (string): The key to store the data under
+            value (any): The data to store
         """
         sess = self._get_session_info(session_id)
         sess.setdefault('serverdict', {})[key] = value
@@ -190,9 +211,11 @@ class AuthHandler(BaseHandler):
     def get_session_data(self, session_id, key, default=None):
         """
         Retrieve data stored with set_session_data
-        :param session_id: (string) The ID of this session as returned from check_auth
-        :param key: (string) The key to store the data under
-        :param default: (any) Value to return if the key has not been set
+
+        Args:
+            session_id (string): The ID of this session as returned from check_auth
+            key (string): The key to store the data under
+            default (any): Value to return if the key has not been set
         """
         sess = self._get_session_info(session_id)
         return sess.setdefault('serverdict', {}).get(key, default)
@@ -207,8 +230,10 @@ class AuthHandler(BaseHandler):
         if not user_id.startswith('@'):
             user_id = UserID.create(user_id, self.hs.hostname).to_string()
 
-        user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
-        self._check_password(user_id, password, password_hash)
+        if not (yield self._check_password(user_id, password)):
+            logger.warn("Failed password login for user %s", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
         defer.returnValue(user_id)
 
     @defer.inlineCallbacks
@@ -332,8 +357,10 @@ class AuthHandler(BaseHandler):
             StoreError if there was a problem storing the token.
             LoginError if there was an authentication problem.
         """
-        user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
-        self._check_password(user_id, password, password_hash)
+
+        if not (yield self._check_password(user_id, password)):
+            logger.warn("Failed password login for user %s", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
 
         logger.info("Logging in user %s", user_id)
         access_token = yield self.issue_access_token(user_id)
@@ -399,11 +426,67 @@ class AuthHandler(BaseHandler):
         else:
             defer.returnValue(user_infos.popitem())
 
-    def _check_password(self, user_id, password, stored_hash):
-        """Checks that user_id has passed password, raises LoginError if not."""
-        if not self.validate_hash(password, stored_hash):
-            logger.warn("Failed password login for user %s", user_id)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+    @defer.inlineCallbacks
+    def _check_password(self, user_id, password):
+        """
+        Returns:
+            True if the user_id successfully authenticated
+        """
+        valid_ldap = yield self._check_ldap_password(user_id, password)
+        if valid_ldap:
+            defer.returnValue(True)
+
+        valid_local_password = yield self._check_local_password(user_id, password)
+        if valid_local_password:
+            defer.returnValue(True)
+
+        defer.returnValue(False)
+
+    @defer.inlineCallbacks
+    def _check_local_password(self, user_id, password):
+        try:
+            user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+            defer.returnValue(self.validate_hash(password, password_hash))
+        except LoginError:
+            defer.returnValue(False)
+
+    @defer.inlineCallbacks
+    def _check_ldap_password(self, user_id, password):
+        if not self.ldap_enabled:
+            logger.debug("LDAP not configured")
+            defer.returnValue(False)
+
+        import ldap
+
+        logger.info("Authenticating %s with LDAP" % user_id)
+        try:
+            ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
+            logger.debug("Connecting LDAP server at %s" % ldap_url)
+            l = ldap.initialize(ldap_url)
+            if self.ldap_tls:
+                logger.debug("Initiating TLS")
+                self._connection.start_tls_s()
+
+            local_name = UserID.from_string(user_id).localpart
+
+            dn = "%s=%s, %s" % (
+                self.ldap_search_property,
+                local_name,
+                self.ldap_search_base)
+            logger.debug("DN for LDAP authentication: %s" % dn)
+
+            l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
+
+            if not (yield self.does_user_exist(user_id)):
+                handler = self.hs.get_handlers().registration_handler
+                user_id, access_token = (
+                    yield handler.register(localpart=local_name)
+                )
+
+            defer.returnValue(True)
+        except ldap.LDAPError, e:
+            logger.warn("LDAP error: %s", e)
+            defer.returnValue(False)
 
     @defer.inlineCallbacks
     def issue_access_token(self, user_id):
@@ -438,14 +521,19 @@ class AuthHandler(BaseHandler):
         ))
         return m.serialize()
 
-    def generate_short_term_login_token(self, user_id):
+    def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()
-        expiry = now + (2 * 60 * 1000)
+        expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         return macaroon.serialize()
 
+    def generate_delete_pusher_token(self, user_id):
+        macaroon = self._generate_base_macaroon(user_id)
+        macaroon.add_first_party_caveat("type = delete_pusher")
+        return macaroon.serialize()
+
     def validate_short_term_login_token_and_get_user_id(self, login_token):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(login_token)
@@ -480,7 +568,12 @@ class AuthHandler(BaseHandler):
 
         except_access_token_ids = [requester.access_token_id] if requester else []
 
-        yield self.store.user_set_password_hash(user_id, password_hash)
+        try:
+            yield self.store.user_set_password_hash(user_id, password_hash)
+        except StoreError as e:
+            if e.code == 404:
+                raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
+            raise e
         yield self.store.user_delete_access_tokens(
             user_id, except_access_token_ids
         )
@@ -532,4 +625,7 @@ class AuthHandler(BaseHandler):
         Returns:
             Whether self.hash(password) == stored_hash (bool).
         """
-        return bcrypt.hashpw(password, stored_hash) == stored_hash
+        if stored_hash:
+            return bcrypt.hashpw(password, stored_hash) == stored_hash
+        else:
+            return False