summary refs log tree commit diff
path: root/synapse/handlers/register.py
diff options
context:
space:
mode:
authorDavid Baker <dave@matrix.org>2015-04-16 19:56:44 +0100
committerDavid Baker <dave@matrix.org>2015-04-16 19:56:44 +0100
commitea1776f556edaf6ca483bc5faed5e9d244aa1a15 (patch)
tree41d1cc129f4ed2eb5661ee2e0fece413943a6340 /synapse/handlers/register.py
parentDummy login so we can do the first POST request to get login flows without it... (diff)
downloadsynapse-ea1776f556edaf6ca483bc5faed5e9d244aa1a15.tar.xz
Return user ID in use error straight away
Diffstat (limited to 'synapse/handlers/register.py')
-rw-r--r--synapse/handlers/register.py102
1 files changed, 39 insertions, 63 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 6759a8c582..541b1019da 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -45,6 +45,36 @@ class RegistrationHandler(BaseHandler):
         self.distributor.declare("registered_user")
 
     @defer.inlineCallbacks
+    def check_username(self, localpart):
+        yield run_on_reactor()
+
+        print "checking username %s" % (localpart)
+
+        if urllib.quote(localpart) != localpart:
+            raise SynapseError(
+                400,
+                "User ID must only contain characters which do not"
+                " require URL encoding."
+            )
+
+        user = UserID(localpart, self.hs.hostname)
+        user_id = user.to_string()
+
+        yield self.check_user_id_is_valid(user_id)
+
+        print "is valid"
+
+        u = yield self.store.get_user_by_id(user_id)
+        print "user is: "
+        print u
+        if u:
+            raise SynapseError(
+                400,
+                "User ID already taken.",
+                errcode=Codes.USER_IN_USE,
+            )
+
+    @defer.inlineCallbacks
     def register(self, localpart=None, password=None):
         """Registers a new client on the server.
 
@@ -64,18 +94,11 @@ class RegistrationHandler(BaseHandler):
             password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
 
         if localpart:
-            if localpart and urllib.quote(localpart) != localpart:
-                raise SynapseError(
-                    400,
-                    "User ID must only contain characters which do not"
-                    " require URL encoding."
-                )
+            self.check_username(localpart)
 
             user = UserID(localpart, self.hs.hostname)
             user_id = user.to_string()
 
-            yield self.check_user_id_is_valid(user_id)
-
             token = self._generate_token(user_id)
             yield self.store.register(
                 user_id=user_id,
@@ -190,7 +213,8 @@ class RegistrationHandler(BaseHandler):
             logger.info("validating theeepidcred sid %s on id server %s",
                         c['sid'], c['idServer'])
             try:
-                threepid = yield self._threepid_from_creds(c)
+                identity_handler = self.hs.get_handlers().identity_handler
+                threepid = yield identity_handler.threepid_from_creds(c)
             except:
                 logger.exception("Couldn't validate 3pid")
                 raise RegistrationError(400, "Couldn't validate 3pid")
@@ -202,12 +226,16 @@ class RegistrationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def bind_emails(self, user_id, threepidCreds):
-        """Links emails with a user ID and informs an identity server."""
+        """Links emails with a user ID and informs an identity server.
+
+        Used only by c/s api v1
+        """
 
         # Now we have a matrix ID, bind it to the threepids we were given
         for c in threepidCreds:
+            identity_handler = self.hs.get_handlers().identity_handler
             # XXX: This should be a deferred list, shouldn't it?
-            yield self._bind_threepid(c, user_id)
+            yield identity_handler.bind_threepid(c, user_id)
 
     @defer.inlineCallbacks
     def check_user_id_is_valid(self, user_id):
@@ -235,58 +263,6 @@ class RegistrationHandler(BaseHandler):
         return "-" + stringutils.random_string(18)
 
     @defer.inlineCallbacks
-    def _threepid_from_creds(self, creds):
-        # TODO: get this from the homeserver rather than creating a new one for
-        # each request
-        http_client = SimpleHttpClient(self.hs)
-        # XXX: make this configurable!
-        trustedIdServers = ['matrix.org:8090', 'matrix.org']
-        if not creds['idServer'] in trustedIdServers:
-            logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
-                        'credentials', creds['idServer'])
-            defer.returnValue(None)
-
-        data = {}
-        try:
-            data = yield http_client.get_json(
-                # XXX: This should be HTTPS
-                "http://%s%s" % (
-                    creds['idServer'],
-                    "/_matrix/identity/api/v1/3pid/getValidated3pid"
-                ),
-                {'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
-            )
-        except CodeMessageException as e:
-            data = json.loads(e.msg)
-
-        if 'medium' in data:
-            defer.returnValue(data)
-        defer.returnValue(None)
-
-    @defer.inlineCallbacks
-    def _bind_threepid(self, creds, mxid):
-        yield
-        logger.debug("binding threepid")
-        http_client = SimpleHttpClient(self.hs)
-        data = None
-        try:
-            data = yield http_client.post_urlencoded_get_json(
-                # XXX: Change when ID servers are all HTTPS
-                "http://%s%s" % (
-                    creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
-                ),
-                {
-                    'sid': creds['sid'],
-                    'clientSecret': creds['clientSecret'],
-                    'mxid': mxid,
-                }
-            )
-            logger.debug("bound threepid")
-        except CodeMessageException as e:
-            data = json.loads(e.msg)
-        defer.returnValue(data)
-
-    @defer.inlineCallbacks
     def _validate_captcha(self, ip_addr, private_key, challenge, response):
         """Validates the captcha provided.