summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDaniel Wagner-Hall <dawagner@gmail.com>2015-08-26 13:49:38 +0100
committerDaniel Wagner-Hall <dawagner@gmail.com>2015-08-26 13:49:38 +0100
commit6f0c344ca7fd6e5dd6109d47ecdf90fa63f75f71 (patch)
tree22fc258521ec6c2e46c84a1cd579deb2ca8a7cfb /synapse
parentMerge pull request #253 from matrix-org/tox (diff)
parentMerge erikj/user_dedup to develop (diff)
downloadsynapse-6f0c344ca7fd6e5dd6109d47ecdf90fa63f75f71.tar.xz
Merge pull request #255 from matrix-org/mergeeriksmadness
Merge erikj/user_dedup to develop
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/auth.py39
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/rest/client/v1/login.py5
-rw-r--r--synapse/storage/registration.py14
4 files changed, 50 insertions, 12 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c983d444e8..1ab19cd1a6 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -163,7 +163,8 @@ class AuthHandler(BaseHandler):
         if not user_id.startswith('@'):
             user_id = UserID.create(user_id, self.hs.hostname).to_string()
 
-        yield self._check_password(user_id, password)
+        user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+        self._check_password(user_id, password, password_hash)
         defer.returnValue(user_id)
 
     @defer.inlineCallbacks
@@ -280,27 +281,49 @@ class AuthHandler(BaseHandler):
             password (str): Password
         Returns:
             A tuple of:
+              The user's ID.
               The access token for the user's session.
               The refresh token for the user's session.
         Raises:
             StoreError if there was a problem storing the token.
             LoginError if there was an authentication problem.
         """
-        yield self._check_password(user_id, password)
+        user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+        self._check_password(user_id, password, password_hash)
+
         logger.info("Logging in user %s", user_id)
         access_token = yield self.issue_access_token(user_id)
         refresh_token = yield self.issue_refresh_token(user_id)
-        defer.returnValue((access_token, refresh_token))
+        defer.returnValue((user_id, access_token, refresh_token))
 
     @defer.inlineCallbacks
-    def _check_password(self, user_id, password):
-        """Checks that user_id has passed password, raises LoginError if not."""
-        user_info = yield self.store.get_user_by_id(user_id=user_id)
-        if not user_info:
+    def _find_user_id_and_pwd_hash(self, user_id):
+        """Checks to see if a user with the given id exists. Will check case
+        insensitively, but will throw if there are multiple inexact matches.
+
+        Returns:
+            tuple: A 2-tuple of `(canonical_user_id, password_hash)`
+        """
+        user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
+        if not user_infos:
             logger.warn("Attempted to login as %s but they do not exist", user_id)
             raise LoginError(403, "", errcode=Codes.FORBIDDEN)
 
-        stored_hash = user_info["password_hash"]
+        if len(user_infos) > 1:
+            if user_id not in user_infos:
+                logger.warn(
+                    "Attempted to login as %s but it matches more than one user "
+                    "inexactly: %r",
+                    user_id, user_infos.keys()
+                )
+                raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+            defer.returnValue((user_id, user_infos[user_id]))
+        else:
+            defer.returnValue(user_infos.popitem())
+
+    def _check_password(self, user_id, password, stored_hash):
+        """Checks that user_id has passed password, raises LoginError if not."""
         if not bcrypt.checkpw(password, stored_hash):
             logger.warn("Failed password login for user %s", user_id)
             raise LoginError(403, "", errcode=Codes.FORBIDDEN)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 3d1b6531c2..56d125f753 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -56,8 +56,8 @@ class RegistrationHandler(BaseHandler):
 
         yield self.check_user_id_is_valid(user_id)
 
-        u = yield self.store.get_user_by_id(user_id)
-        if u:
+        users = yield self.store.get_users_by_id_case_insensitive(user_id)
+        if users:
             raise SynapseError(
                 400,
                 "User ID already taken.",
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 3a0707c2ee..e580f71964 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -83,10 +83,11 @@ class LoginRestServlet(ClientV1RestServlet):
 
         if not user_id.startswith('@'):
             user_id = UserID.create(
-                user_id, self.hs.hostname).to_string()
+                user_id, self.hs.hostname
+            ).to_string()
 
         auth_handler = self.handlers.auth_handler
-        access_token, refresh_token = yield auth_handler.login_with_password(
+        user_id, access_token, refresh_token = yield auth_handler.login_with_password(
             user_id=user_id,
             password=login_submission["password"])
 
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index a2d0f7c4b1..c9ceb132ae 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -120,6 +120,20 @@ class RegistrationStore(SQLBaseStore):
             allow_none=True,
         )
 
+    def get_users_by_id_case_insensitive(self, user_id):
+        """Gets users that match user_id case insensitively.
+        Returns a mapping of user_id -> password_hash.
+        """
+        def f(txn):
+            sql = (
+                "SELECT name, password_hash FROM users"
+                " WHERE lower(name) = lower(?)"
+            )
+            txn.execute(sql, (user_id,))
+            return dict(txn.fetchall())
+
+        return self.runInteraction("get_users_by_id_case_insensitive", f)
+
     @defer.inlineCallbacks
     def user_set_password_hash(self, user_id, password_hash):
         """