summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/auth.py33
-rw-r--r--synapse/rest/client/v1/login.py5
2 files changed, 31 insertions, 7 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index cc667b6d8b..0337be36c2 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -270,6 +270,7 @@ class AuthHandler(BaseHandler):
         sess = self._get_session_info(session_id)
         return sess.setdefault('serverdict', {}).get(key, default)
 
+    @defer.inlineCallbacks
     def _check_password_auth(self, authdict, _):
         if "user" not in authdict or "password" not in authdict:
             raise LoginError(400, "", Codes.MISSING_PARAM)
@@ -277,10 +278,11 @@ class AuthHandler(BaseHandler):
         user_id = authdict["user"]
         password = authdict["password"]
 
-        return self.validate_login(user_id, {
+        (canonical_id, callback) = yield self.validate_login(user_id, {
             "type": LoginType.PASSWORD,
             "password": password,
         })
+        defer.returnValue(canonical_id)
 
     @defer.inlineCallbacks
     def _check_recaptcha(self, authdict, clientip):
@@ -517,7 +519,8 @@ class AuthHandler(BaseHandler):
             login_submission (dict): the whole of the login submission
                 (including 'type' and other relevant fields)
         Returns:
-            Deferred[str]: canonical user id
+            Deferred[str, func]: canonical user id, and optional callback
+                to be called once the access token and device id are issued
         Raises:
             StoreError if there was a problem accessing the database
             SynapseError if there was a problem with the request
@@ -581,11 +584,13 @@ class AuthHandler(BaseHandler):
                     ),
                 )
 
-            returned_user_id = yield provider.check_auth(
+            result = yield provider.check_auth(
                 username, login_type, login_dict,
             )
-            if returned_user_id:
-                defer.returnValue(returned_user_id)
+            if result:
+                if isinstance(result, str):
+                    result = (result, None)
+                defer.returnValue(result)
 
         if login_type == LoginType.PASSWORD:
             known_login_type = True
@@ -595,7 +600,7 @@ class AuthHandler(BaseHandler):
             )
 
             if canonical_user_id:
-                defer.returnValue(canonical_user_id)
+                defer.returnValue((canonical_user_id, None))
 
         if not known_login_type:
             raise SynapseError(400, "Unknown login type %s" % login_type)
@@ -848,6 +853,7 @@ class _AccountHandler(object):
         self.hs = hs
 
         self._check_user_exists = check_user_exists
+        self._store = hs.get_datastore()
 
     def get_qualified_user_id(self, username):
         """Qualify a user id, if necessary
@@ -885,3 +891,18 @@ class _AccountHandler(object):
         """
         reg = self.hs.get_handlers().registration_handler
         return reg.register(localpart=localpart)
+
+    def run_db_interaction(self, desc, func, *args, **kwargs):
+        """Run a function with a database connection
+
+        Args:
+            desc (str): description for the transaction, for metrics etc
+            func (func): function to be run. Passed a database cursor object
+                as well as *args and **kwargs
+            *args: positional args to be passed to func
+            **kwargs: named args to be passed to func
+
+        Returns:
+            Deferred[object]: result of func
+        """
+        return self._store.runInteraction(desc, func, *args, **kwargs)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d25a68e753..5669ecb724 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -219,7 +219,7 @@ class LoginRestServlet(ClientV1RestServlet):
             raise SynapseError(400, "User identifier is missing 'user' key")
 
         auth_handler = self.auth_handler
-        canonical_user_id = yield auth_handler.validate_login(
+        canonical_user_id, callback = yield auth_handler.validate_login(
             identifier["user"],
             login_submission,
         )
@@ -238,6 +238,9 @@ class LoginRestServlet(ClientV1RestServlet):
             "device_id": device_id,
         }
 
+        if callback is not None:
+            yield callback(result)
+
         defer.returnValue((200, result))
 
     @defer.inlineCallbacks