summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
authorSteven Hammerton <steven.hammerton@openmarket.com>2015-11-11 11:20:23 +0000
committerSteven Hammerton <steven.hammerton@openmarket.com>2015-11-11 11:21:43 +0000
commit2b779af10fe5c39f6119acddb5290be2b2a5930f (patch)
tree77cede9eab89d7d30e8362d984e429dc1050e236 /synapse/rest
parentShare more code between macaroon validation (diff)
downloadsynapse-2b779af10fe5c39f6119acddb5290be2b2a5930f.tar.xz
Minor review fixes
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/login.py23
1 files changed, 10 insertions, 13 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5a2cedacb0..78c542a94a 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -146,7 +146,7 @@ class LoginRestServlet(ClientV1RestServlet):
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
         user_id, access_token, refresh_token = (
-            yield auth_handler.login_with_user_id(user_id)
+            yield auth_handler.get_login_tuple_for_user_id(user_id)
         )
         result = {
             "user_id": user_id,  # may have changed
@@ -179,7 +179,7 @@ class LoginRestServlet(ClientV1RestServlet):
         user_exists = yield auth_handler.does_user_exist(user_id)
         if user_exists:
             user_id, access_token, refresh_token = (
-                yield auth_handler.login_with_user_id(user_id)
+                yield auth_handler.get_login_tuple_for_user_id(user_id)
             )
             result = {
                 "user_id": user_id,  # may have changed
@@ -304,7 +304,6 @@ class CasRedirectServlet(ClientV1RestServlet):
         })
         request.redirect("%s?%s" % (self.cas_server_url, serviceParam))
         request.finish()
-        defer.returnValue(None)
 
 
 class CasTicketServlet(ClientV1RestServlet):
@@ -318,21 +317,19 @@ class CasTicketServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        clientRedirectUrl = request.args["redirectUrl"][0]
-        # TODO: get this from the homeserver rather than creating a new one for
-        # each request
-        http_client = SimpleHttpClient(self.hs)
+        client_redirect_url = request.args["redirectUrl"][0]
+        http_client = self.hs.get_simple_http_client()
         uri = self.cas_server_url + "/proxyValidate"
         args = {
             "ticket": request.args["ticket"],
             "service": self.cas_service_url
         }
         body = yield http_client.get_raw(uri, args)
-        result = yield self.handle_cas_response(request, body, clientRedirectUrl)
+        result = yield self.handle_cas_response(request, body, client_redirect_url)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def handle_cas_response(self, request, cas_response_body, clientRedirectUrl):
+    def handle_cas_response(self, request, cas_response_body, client_redirect_url):
         user, attributes = self.parse_cas_response(cas_response_body)
 
         for required_attribute, required_value in self.cas_required_attributes.items():
@@ -351,15 +348,15 @@ class CasTicketServlet(ClientV1RestServlet):
         auth_handler = self.handlers.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if not user_exists:
-            user_id, ignored = (
+            user_id, _ = (
                 yield self.handlers.registration_handler.register(localpart=user)
             )
 
         login_token = auth_handler.generate_short_term_login_token(user_id)
-        redirectUrl = self.add_login_token_to_redirect_url(clientRedirectUrl, login_token)
-        request.redirect(redirectUrl)
+        redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
+                                                            login_token)
+        request.redirect(redirect_url)
         request.finish()
-        defer.returnValue(None)
 
     def add_login_token_to_redirect_url(self, url, token):
         url_parts = list(urlparse.urlparse(url))