summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
authorMark Haines <mjark@negativecurvature.net>2016-01-04 14:02:50 +0000
committerMark Haines <mjark@negativecurvature.net>2016-01-04 14:02:50 +0000
commitf35f8d06ea58e2d0cdccd82924c7a44fd93f4c38 (patch)
treedc5312558565f8ac01264be21d388e563a5c8c58 /synapse/rest/client/v1/login.py
parentAdded info abou Martin Giess' auto-deployment process with vagrant/ansible (diff)
parentBump changelog and version for v0.12.0 (diff)
downloadsynapse-f35f8d06ea58e2d0cdccd82924c7a44fd93f4c38.tar.xz
Merge remote-tracking branch 'origin/release-v0.12.0' v0.12.0
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py19
1 files changed, 8 insertions, 11 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 720d6358e7..e8c35508cd 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -16,9 +16,8 @@
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, LoginError, Codes
-from synapse.http.client import SimpleHttpClient
 from synapse.types import UserID
-from base import ClientV1RestServlet, client_path_pattern
+from base import ClientV1RestServlet, client_path_patterns
 
 import simplejson as json
 import urllib
@@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
 
 
 class LoginRestServlet(ClientV1RestServlet):
-    PATTERN = client_path_pattern("/login$")
+    PATTERNS = client_path_patterns("/login$")
     PASS_TYPE = "m.login.password"
     SAML2_TYPE = "m.login.saml2"
     CAS_TYPE = "m.login.cas"
@@ -51,6 +50,7 @@ class LoginRestServlet(ClientV1RestServlet):
         self.cas_server_url = hs.config.cas_server_url
         self.cas_required_attributes = hs.config.cas_required_attributes
         self.servername = hs.config.server_name
+        self.http_client = hs.get_simple_http_client()
 
     def on_GET(self, request):
         flows = []
@@ -98,15 +98,12 @@ class LoginRestServlet(ClientV1RestServlet):
             # TODO Delete this after all CAS clients switch to token login instead
             elif self.cas_enabled and (login_submission["type"] ==
                                        LoginRestServlet.CAS_TYPE):
-                # TODO: get this from the homeserver rather than creating a new one for
-                # each request
-                http_client = SimpleHttpClient(self.hs)
                 uri = "%s/proxyValidate" % (self.cas_server_url,)
                 args = {
                     "ticket": login_submission["ticket"],
                     "service": login_submission["service"]
                 }
-                body = yield http_client.get_raw(uri, args)
+                body = yield self.http_client.get_raw(uri, args)
                 result = yield self.do_cas_login(body)
                 defer.returnValue(result)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
@@ -238,7 +235,7 @@ class LoginRestServlet(ClientV1RestServlet):
 
 
 class SAML2RestServlet(ClientV1RestServlet):
-    PATTERN = client_path_pattern("/login/saml2")
+    PATTERNS = client_path_patterns("/login/saml2", releases=())
 
     def __init__(self, hs):
         super(SAML2RestServlet, self).__init__(hs)
@@ -282,7 +279,7 @@ class SAML2RestServlet(ClientV1RestServlet):
 
 # TODO Delete this after all CAS clients switch to token login instead
 class CasRestServlet(ClientV1RestServlet):
-    PATTERN = client_path_pattern("/login/cas")
+    PATTERNS = client_path_patterns("/login/cas", releases=())
 
     def __init__(self, hs):
         super(CasRestServlet, self).__init__(hs)
@@ -293,7 +290,7 @@ class CasRestServlet(ClientV1RestServlet):
 
 
 class CasRedirectServlet(ClientV1RestServlet):
-    PATTERN = client_path_pattern("/login/cas/redirect")
+    PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
 
     def __init__(self, hs):
         super(CasRedirectServlet, self).__init__(hs)
@@ -316,7 +313,7 @@ class CasRedirectServlet(ClientV1RestServlet):
 
 
 class CasTicketServlet(ClientV1RestServlet):
-    PATTERN = client_path_pattern("/login/cas/ticket")
+    PATTERNS = client_path_patterns("/login/cas/ticket", releases=())
 
     def __init__(self, hs):
         super(CasTicketServlet, self).__init__(hs)