summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py149
1 files changed, 139 insertions, 10 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 602c5bcd89..be157e2bb7 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,14 +18,14 @@ from twisted.internet import defer
 from ._base import BaseHandler
 from synapse.api.constants import LoginType
 from synapse.types import UserID
-from synapse.api.errors import LoginError, Codes
-from synapse.http.client import SimpleHttpClient
+from synapse.api.errors import AuthError, LoginError, Codes
 from synapse.util.async import run_on_reactor
 
 from twisted.web.client import PartialDownloadError
 
 import logging
 import bcrypt
+import pymacaroons
 import simplejson
 
 import synapse.util.stringutils as stringutils
@@ -44,7 +44,9 @@ class AuthHandler(BaseHandler):
             LoginType.EMAIL_IDENTITY: self._check_email_identity,
             LoginType.DUMMY: self._check_dummy_auth,
         }
+        self.bcrypt_rounds = hs.config.bcrypt_rounds
         self.sessions = {}
+        self.INVALID_TOKEN_HTTP_STATUS = 401
 
     @defer.inlineCallbacks
     def check_auth(self, flows, clientdict, clientip):
@@ -186,7 +188,7 @@ class AuthHandler(BaseHandler):
         # TODO: get this from the homeserver rather than creating a new one for
         # each request
         try:
-            client = SimpleHttpClient(self.hs)
+            client = self.hs.get_simple_http_client()
             resp_body = yield client.post_urlencoded_get_json(
                 self.hs.config.recaptcha_siteverify_api,
                 args={
@@ -279,7 +281,10 @@ class AuthHandler(BaseHandler):
             user_id (str): User ID
             password (str): Password
         Returns:
-            The access token for the user's session.
+            A tuple of:
+              The user's ID.
+              The access token for the user's session.
+              The refresh token for the user's session.
         Raises:
             StoreError if there was a problem storing the token.
             LoginError if there was an authentication problem.
@@ -287,11 +292,43 @@ class AuthHandler(BaseHandler):
         user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
         self._check_password(user_id, password, password_hash)
 
-        reg_handler = self.hs.get_handlers().registration_handler
-        access_token = reg_handler.generate_token(user_id)
         logger.info("Logging in user %s", user_id)
-        yield self.store.add_access_token_to_user(user_id, access_token)
-        defer.returnValue((user_id, access_token))
+        access_token = yield self.issue_access_token(user_id)
+        refresh_token = yield self.issue_refresh_token(user_id)
+        defer.returnValue((user_id, access_token, refresh_token))
+
+    @defer.inlineCallbacks
+    def get_login_tuple_for_user_id(self, user_id):
+        """
+        Gets login tuple for the user with the given user ID.
+        The user is assumed to have been authenticated by some other
+        machanism (e.g. CAS)
+
+        Args:
+            user_id (str): User ID
+        Returns:
+            A tuple of:
+              The user's ID.
+              The access token for the user's session.
+              The refresh token for the user's session.
+        Raises:
+            StoreError if there was a problem storing the token.
+            LoginError if there was an authentication problem.
+        """
+        user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
+
+        logger.info("Logging in user %s", user_id)
+        access_token = yield self.issue_access_token(user_id)
+        refresh_token = yield self.issue_refresh_token(user_id)
+        defer.returnValue((user_id, access_token, refresh_token))
+
+    @defer.inlineCallbacks
+    def does_user_exist(self, user_id):
+        try:
+            yield self._find_user_id_and_pwd_hash(user_id)
+            defer.returnValue(True)
+        except LoginError:
+            defer.returnValue(False)
 
     @defer.inlineCallbacks
     def _find_user_id_and_pwd_hash(self, user_id):
@@ -321,13 +358,82 @@ class AuthHandler(BaseHandler):
 
     def _check_password(self, user_id, password, stored_hash):
         """Checks that user_id has passed password, raises LoginError if not."""
-        if not bcrypt.checkpw(password, stored_hash):
+        if not self.validate_hash(password, stored_hash):
             logger.warn("Failed password login for user %s", user_id)
             raise LoginError(403, "", errcode=Codes.FORBIDDEN)
 
     @defer.inlineCallbacks
+    def issue_access_token(self, user_id):
+        access_token = self.generate_access_token(user_id)
+        yield self.store.add_access_token_to_user(user_id, access_token)
+        defer.returnValue(access_token)
+
+    @defer.inlineCallbacks
+    def issue_refresh_token(self, user_id):
+        refresh_token = self.generate_refresh_token(user_id)
+        yield self.store.add_refresh_token_to_user(user_id, refresh_token)
+        defer.returnValue(refresh_token)
+
+    def generate_access_token(self, user_id, extra_caveats=None):
+        extra_caveats = extra_caveats or []
+        macaroon = self._generate_base_macaroon(user_id)
+        macaroon.add_first_party_caveat("type = access")
+        now = self.hs.get_clock().time_msec()
+        expiry = now + (60 * 60 * 1000)
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        for caveat in extra_caveats:
+            macaroon.add_first_party_caveat(caveat)
+        return macaroon.serialize()
+
+    def generate_refresh_token(self, user_id):
+        m = self._generate_base_macaroon(user_id)
+        m.add_first_party_caveat("type = refresh")
+        # Important to add a nonce, because otherwise every refresh token for a
+        # user will be the same.
+        m.add_first_party_caveat("nonce = %s" % (
+            stringutils.random_string_with_symbols(16),
+        ))
+        return m.serialize()
+
+    def generate_short_term_login_token(self, user_id):
+        macaroon = self._generate_base_macaroon(user_id)
+        macaroon.add_first_party_caveat("type = login")
+        now = self.hs.get_clock().time_msec()
+        expiry = now + (2 * 60 * 1000)
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        return macaroon.serialize()
+
+    def validate_short_term_login_token_and_get_user_id(self, login_token):
+        try:
+            macaroon = pymacaroons.Macaroon.deserialize(login_token)
+            auth_api = self.hs.get_auth()
+            auth_api.validate_macaroon(macaroon, "login", [auth_api.verify_expiry])
+            return self._get_user_from_macaroon(macaroon)
+        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+            raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
+
+    def _generate_base_macaroon(self, user_id):
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+        return macaroon
+
+    def _get_user_from_macaroon(self, macaroon):
+        user_prefix = "user_id = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(user_prefix):
+                return caveat.caveat_id[len(user_prefix):]
+        raise AuthError(
+            self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
+            errcode=Codes.UNKNOWN_TOKEN
+        )
+
+    @defer.inlineCallbacks
     def set_password(self, user_id, newpassword):
-        password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
+        password_hash = self.hash(newpassword)
 
         yield self.store.user_set_password_hash(user_id, password_hash)
         yield self.store.user_delete_access_tokens(user_id)
@@ -349,3 +455,26 @@ class AuthHandler(BaseHandler):
     def _remove_session(self, session):
         logger.debug("Removing session %s", session)
         del self.sessions[session["id"]]
+
+    def hash(self, password):
+        """Computes a secure hash of password.
+
+        Args:
+            password (str): Password to hash.
+
+        Returns:
+            Hashed password (str).
+        """
+        return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
+
+    def validate_hash(self, password, stored_hash):
+        """Validates that self.hash(password) == stored_hash.
+
+        Args:
+            password (str): Password to hash.
+            stored_hash (str): Expected hash value.
+
+        Returns:
+            Whether self.hash(password) == stored_hash (bool).
+        """
+        return bcrypt.checkpw(password, stored_hash)