diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 602c5bcd89..be157e2bb7 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,14 +18,14 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
-from synapse.api.errors import LoginError, Codes
-from synapse.http.client import SimpleHttpClient
+from synapse.api.errors import AuthError, LoginError, Codes
from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError
import logging
import bcrypt
+import pymacaroons
import simplejson
import synapse.util.stringutils as stringutils
@@ -44,7 +44,9 @@ class AuthHandler(BaseHandler):
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.DUMMY: self._check_dummy_auth,
}
+ self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {}
+ self.INVALID_TOKEN_HTTP_STATUS = 401
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@@ -186,7 +188,7 @@ class AuthHandler(BaseHandler):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
- client = SimpleHttpClient(self.hs)
+ client = self.hs.get_simple_http_client()
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
@@ -279,7 +281,10 @@ class AuthHandler(BaseHandler):
user_id (str): User ID
password (str): Password
Returns:
- The access token for the user's session.
+ A tuple of:
+ The user's ID.
+ The access token for the user's session.
+ The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
@@ -287,11 +292,43 @@ class AuthHandler(BaseHandler):
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
- reg_handler = self.hs.get_handlers().registration_handler
- access_token = reg_handler.generate_token(user_id)
logger.info("Logging in user %s", user_id)
- yield self.store.add_access_token_to_user(user_id, access_token)
- defer.returnValue((user_id, access_token))
+ access_token = yield self.issue_access_token(user_id)
+ refresh_token = yield self.issue_refresh_token(user_id)
+ defer.returnValue((user_id, access_token, refresh_token))
+
+ @defer.inlineCallbacks
+ def get_login_tuple_for_user_id(self, user_id):
+ """
+ Gets login tuple for the user with the given user ID.
+ The user is assumed to have been authenticated by some other
+ machanism (e.g. CAS)
+
+ Args:
+ user_id (str): User ID
+ Returns:
+ A tuple of:
+ The user's ID.
+ The access token for the user's session.
+ The refresh token for the user's session.
+ Raises:
+ StoreError if there was a problem storing the token.
+ LoginError if there was an authentication problem.
+ """
+ user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
+
+ logger.info("Logging in user %s", user_id)
+ access_token = yield self.issue_access_token(user_id)
+ refresh_token = yield self.issue_refresh_token(user_id)
+ defer.returnValue((user_id, access_token, refresh_token))
+
+ @defer.inlineCallbacks
+ def does_user_exist(self, user_id):
+ try:
+ yield self._find_user_id_and_pwd_hash(user_id)
+ defer.returnValue(True)
+ except LoginError:
+ defer.returnValue(False)
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@@ -321,13 +358,82 @@ class AuthHandler(BaseHandler):
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
- if not bcrypt.checkpw(password, stored_hash):
+ if not self.validate_hash(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
+ def issue_access_token(self, user_id):
+ access_token = self.generate_access_token(user_id)
+ yield self.store.add_access_token_to_user(user_id, access_token)
+ defer.returnValue(access_token)
+
+ @defer.inlineCallbacks
+ def issue_refresh_token(self, user_id):
+ refresh_token = self.generate_refresh_token(user_id)
+ yield self.store.add_refresh_token_to_user(user_id, refresh_token)
+ defer.returnValue(refresh_token)
+
+ def generate_access_token(self, user_id, extra_caveats=None):
+ extra_caveats = extra_caveats or []
+ macaroon = self._generate_base_macaroon(user_id)
+ macaroon.add_first_party_caveat("type = access")
+ now = self.hs.get_clock().time_msec()
+ expiry = now + (60 * 60 * 1000)
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ for caveat in extra_caveats:
+ macaroon.add_first_party_caveat(caveat)
+ return macaroon.serialize()
+
+ def generate_refresh_token(self, user_id):
+ m = self._generate_base_macaroon(user_id)
+ m.add_first_party_caveat("type = refresh")
+ # Important to add a nonce, because otherwise every refresh token for a
+ # user will be the same.
+ m.add_first_party_caveat("nonce = %s" % (
+ stringutils.random_string_with_symbols(16),
+ ))
+ return m.serialize()
+
+ def generate_short_term_login_token(self, user_id):
+ macaroon = self._generate_base_macaroon(user_id)
+ macaroon.add_first_party_caveat("type = login")
+ now = self.hs.get_clock().time_msec()
+ expiry = now + (2 * 60 * 1000)
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ return macaroon.serialize()
+
+ def validate_short_term_login_token_and_get_user_id(self, login_token):
+ try:
+ macaroon = pymacaroons.Macaroon.deserialize(login_token)
+ auth_api = self.hs.get_auth()
+ auth_api.validate_macaroon(macaroon, "login", [auth_api.verify_expiry])
+ return self._get_user_from_macaroon(macaroon)
+ except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
+
+ def _generate_base_macaroon(self, user_id):
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ return macaroon
+
+ def _get_user_from_macaroon(self, macaroon):
+ user_prefix = "user_id = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(user_prefix):
+ return caveat.caveat_id[len(user_prefix):]
+ raise AuthError(
+ self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+
+ @defer.inlineCallbacks
def set_password(self, user_id, newpassword):
- password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
+ password_hash = self.hash(newpassword)
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id)
@@ -349,3 +455,26 @@ class AuthHandler(BaseHandler):
def _remove_session(self, session):
logger.debug("Removing session %s", session)
del self.sessions[session["id"]]
+
+ def hash(self, password):
+ """Computes a secure hash of password.
+
+ Args:
+ password (str): Password to hash.
+
+ Returns:
+ Hashed password (str).
+ """
+ return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
+
+ def validate_hash(self, password, stored_hash):
+ """Validates that self.hash(password) == stored_hash.
+
+ Args:
+ password (str): Password to hash.
+ stored_hash (str): Expected hash value.
+
+ Returns:
+ Whether self.hash(password) == stored_hash (bool).
+ """
+ return bcrypt.checkpw(password, stored_hash)
|