summary refs log tree commit diff
diff options
context:
space:
mode:
authorDaniel Wagner-Hall <daniel@matrix.org>2015-08-20 11:35:56 +0100
committerDaniel Wagner-Hall <daniel@matrix.org>2015-08-20 11:35:56 +0100
commit617501dd2a0562f4bf7edf8bc7a4e8aeb16b2254 (patch)
tree566d84bc8f398e21d4035ca9a0147c49363c8eec
parentMerge branch 'auth' into refresh (diff)
downloadsynapse-617501dd2a0562f4bf7edf8bc7a4e8aeb16b2254.tar.xz
Move token generation to auth handler
I prefer the auth handler to worry about all auth, and register to call
into it as needed, than to smatter auth logic between the two.
-rw-r--r--synapse/handlers/auth.py29
-rw-r--r--synapse/handlers/register.py26
-rw-r--r--tests/handlers/test_auth.py (renamed from tests/handlers/test_register.py)14
3 files changed, 38 insertions, 31 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index be2baeaece..0bf917efdd 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -26,6 +26,7 @@ from twisted.web.client import PartialDownloadError
 
 import logging
 import bcrypt
+import pymacaroons
 import simplejson
 
 import synapse.util.stringutils as stringutils
@@ -284,12 +285,9 @@ class AuthHandler(BaseHandler):
             LoginError if there was an authentication problem.
         """
         yield self._check_password(user_id, password)
-
-        reg_handler = self.hs.get_handlers().registration_handler
-        access_token = reg_handler.generate_token(user_id)
         logger.info("Logging in user %s", user_id)
-        yield self.store.add_access_token_to_user(user_id, access_token)
-        defer.returnValue(access_token)
+        token = yield self.issue_access_token(user_id)
+        defer.returnValue(token)
 
     @defer.inlineCallbacks
     def _check_password(self, user_id, password):
@@ -305,6 +303,27 @@ class AuthHandler(BaseHandler):
             raise LoginError(403, "", errcode=Codes.FORBIDDEN)
 
     @defer.inlineCallbacks
+    def issue_access_token(self, user_id):
+        reg_handler = self.hs.get_handlers().registration_handler
+        access_token = reg_handler.generate_access_token(user_id)
+        yield self.store.add_access_token_to_user(user_id, access_token)
+        defer.returnValue(access_token)
+
+    def generate_access_token(self, user_id):
+        macaroon = pymacaroons.Macaroon(
+            location = self.hs.config.server_name,
+            identifier = "key",
+            key = self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+        macaroon.add_first_party_caveat("type = access")
+        now = self.hs.get_clock().time_msec()
+        expiry = now + (60 * 60 * 1000)
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+        return macaroon.serialize()
+
+    @defer.inlineCallbacks
     def set_password(self, user_id, newpassword):
         password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c391c1bdf5..3d1b6531c2 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -27,7 +27,6 @@ from synapse.http.client import CaptchaServerHttpClient
 
 import bcrypt
 import logging
-import pymacaroons
 import urllib
 
 logger = logging.getLogger(__name__)
@@ -91,7 +90,7 @@ class RegistrationHandler(BaseHandler):
             user = UserID(localpart, self.hs.hostname)
             user_id = user.to_string()
 
-            token = self.generate_token(user_id)
+            token = self.auth_handler().generate_access_token(user_id)
             yield self.store.register(
                 user_id=user_id,
                 token=token,
@@ -111,7 +110,7 @@ class RegistrationHandler(BaseHandler):
                     user_id = user.to_string()
                     yield self.check_user_id_is_valid(user_id)
 
-                    token = self.generate_token(user_id)
+                    token = self.auth_handler().generate_access_token(user_id)
                     yield self.store.register(
                         user_id=user_id,
                         token=token,
@@ -161,7 +160,7 @@ class RegistrationHandler(BaseHandler):
                 400, "Invalid user localpart for this application service.",
                 errcode=Codes.EXCLUSIVE
             )
-        token = self.generate_token(user_id)
+        token = self.auth_handler().generate_access_token(user_id)
         yield self.store.register(
             user_id=user_id,
             token=token,
@@ -208,7 +207,7 @@ class RegistrationHandler(BaseHandler):
         user_id = user.to_string()
 
         yield self.check_user_id_is_valid(user_id)
-        token = self.generate_token(user_id)
+        token = self.auth_handler().generate_access_token(user_id)
         try:
             yield self.store.register(
                 user_id=user_id,
@@ -273,20 +272,6 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE
                 )
 
-    def generate_token(self, user_id):
-        macaroon = pymacaroons.Macaroon(
-            location = self.hs.config.server_name,
-            identifier = "key",
-            key = self.hs.config.macaroon_secret_key)
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        macaroon.add_first_party_caveat("type = access")
-        now = self.hs.get_clock().time_msec()
-        expiry = now + (60 * 60 * 1000)
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-
-        return macaroon.serialize()
-
     def _generate_user_id(self):
         return "-" + stringutils.random_string(18)
 
@@ -329,3 +314,6 @@ class RegistrationHandler(BaseHandler):
             }
         )
         defer.returnValue(data)
+
+    def auth_handler(self):
+        return self.hs.get_handlers().auth_handler
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_auth.py
index 91cc90242f..978e4d0d2e 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_auth.py
@@ -16,27 +16,27 @@
 import pymacaroons
 
 from mock import Mock, NonCallableMock
-from synapse.handlers.register import RegistrationHandler
+from synapse.handlers.auth import AuthHandler
 from tests import unittest
 from tests.utils import setup_test_homeserver
 from twisted.internet import defer
 
 
-class RegisterHandlers(object):
+class AuthHandlers(object):
     def __init__(self, hs):
-        self.registration_handler = RegistrationHandler(hs)
+        self.auth_handler = AuthHandler(hs)
 
 
-class RegisterTestCase(unittest.TestCase):
+class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield setup_test_homeserver(handlers=None)
-        self.hs.handlers = RegisterHandlers(self.hs)
+        self.hs.handlers = AuthHandlers(self.hs)
 
     def test_token_is_a_macaroon(self):
         self.hs.config.macaroon_secret_key = "this key is a huge secret"
 
-        token = self.hs.handlers.registration_handler.generate_token("some_user")
+        token = self.hs.handlers.auth_handler.generate_access_token("some_user")
         # Check that we can parse the thing with pymacaroons
         macaroon = pymacaroons.Macaroon.deserialize(token)
         # The most basic of sanity checks
@@ -47,7 +47,7 @@ class RegisterTestCase(unittest.TestCase):
         self.hs.config.macaroon_secret_key = "this key is a massive secret"
         self.hs.clock.now = 5000
 
-        token = self.hs.handlers.registration_handler.generate_token("a_user")
+        token = self.hs.handlers.auth_handler.generate_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
         def verify_gen(caveat):