diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 602c5bcd89..59f687e0f1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -26,6 +26,7 @@ from twisted.web.client import PartialDownloadError
import logging
import bcrypt
+import pymacaroons
import simplejson
import synapse.util.stringutils as stringutils
@@ -279,7 +280,10 @@ class AuthHandler(BaseHandler):
user_id (str): User ID
password (str): Password
Returns:
- The access token for the user's session.
+ A tuple of:
+ The user's ID.
+ The access token for the user's session.
+ The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
@@ -287,11 +291,10 @@ class AuthHandler(BaseHandler):
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
- reg_handler = self.hs.get_handlers().registration_handler
- access_token = reg_handler.generate_token(user_id)
logger.info("Logging in user %s", user_id)
- yield self.store.add_access_token_to_user(user_id, access_token)
- defer.returnValue((user_id, access_token))
+ access_token = yield self.issue_access_token(user_id)
+ refresh_token = yield self.issue_refresh_token(user_id)
+ defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@@ -321,13 +324,52 @@ class AuthHandler(BaseHandler):
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
- if not bcrypt.checkpw(password, stored_hash):
+ if not self.validate_hash(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
+ def issue_access_token(self, user_id):
+ access_token = self.generate_access_token(user_id)
+ yield self.store.add_access_token_to_user(user_id, access_token)
+ defer.returnValue(access_token)
+
+ @defer.inlineCallbacks
+ def issue_refresh_token(self, user_id):
+ refresh_token = self.generate_refresh_token(user_id)
+ yield self.store.add_refresh_token_to_user(user_id, refresh_token)
+ defer.returnValue(refresh_token)
+
+ def generate_access_token(self, user_id):
+ macaroon = self._generate_base_macaroon(user_id)
+ macaroon.add_first_party_caveat("type = access")
+ now = self.hs.get_clock().time_msec()
+ expiry = now + (60 * 60 * 1000)
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ return macaroon.serialize()
+
+ def generate_refresh_token(self, user_id):
+ m = self._generate_base_macaroon(user_id)
+ m.add_first_party_caveat("type = refresh")
+ # Important to add a nonce, because otherwise every refresh token for a
+ # user will be the same.
+ m.add_first_party_caveat("nonce = %s" % (
+ stringutils.random_string_with_symbols(16),
+ ))
+ return m.serialize()
+
+ def _generate_base_macaroon(self, user_id):
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ return macaroon
+
+ @defer.inlineCallbacks
def set_password(self, user_id, newpassword):
- password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
+ password_hash = self.hash(newpassword)
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id)
@@ -349,3 +391,26 @@ class AuthHandler(BaseHandler):
def _remove_session(self, session):
logger.debug("Removing session %s", session)
del self.sessions[session["id"]]
+
+ def hash(self, password):
+ """Computes a secure hash of password.
+
+ Args:
+ password (str): Password to hash.
+
+ Returns:
+ Hashed password (str).
+ """
+ return bcrypt.hashpw(password, bcrypt.gensalt())
+
+ def validate_hash(self, password, stored_hash):
+ """Validates that self.hash(password) == stored_hash.
+
+ Args:
+ password (str): Password to hash.
+ stored_hash (str): Expected hash value.
+
+ Returns:
+ Whether self.hash(password) == stored_hash (bool).
+ """
+ return bcrypt.checkpw(password, stored_hash)
|