summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py64
1 files changed, 52 insertions, 12 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index aa5d89a9ac..7f8ddc99c6 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -162,7 +162,7 @@ class AuthHandler(BaseHandler):
         defer.returnValue(params)
 
     @defer.inlineCallbacks
-    def check_auth(self, flows, clientdict, clientip):
+    def check_auth(self, flows, clientdict, clientip, password_servlet=False):
         """
         Takes a dictionary sent by the client in the login / registration
         protocol and handles the User-Interactive Auth flow.
@@ -186,6 +186,16 @@ class AuthHandler(BaseHandler):
 
             clientip (str): The IP address of the client.
 
+            password_servlet (bool): Whether the request originated from
+                PasswordRestServlet.
+                XXX: This is a temporary hack to distinguish between checking
+                for threepid validations locally (in the case of password
+                resets) and using the identity server (in the case of binding
+                a 3PID during registration). Once we start using the
+                homeserver for both tasks, this distinction will no longer be
+                necessary.
+
+
         Returns:
             defer.Deferred[dict, dict, str]: a deferred tuple of
                 (creds, params, session_id).
@@ -241,7 +251,9 @@ class AuthHandler(BaseHandler):
         if 'type' in authdict:
             login_type = authdict['type']
             try:
-                result = yield self._check_auth_dict(authdict, clientip)
+                result = yield self._check_auth_dict(
+                    authdict, clientip, password_servlet=password_servlet,
+                )
                 if result:
                     creds[login_type] = result
                     self._save_session(session)
@@ -351,7 +363,7 @@ class AuthHandler(BaseHandler):
         return sess.setdefault('serverdict', {}).get(key, default)
 
     @defer.inlineCallbacks
-    def _check_auth_dict(self, authdict, clientip):
+    def _check_auth_dict(self, authdict, clientip, password_servlet=False):
         """Attempt to validate the auth dict provided by a client
 
         Args:
@@ -369,7 +381,13 @@ class AuthHandler(BaseHandler):
         login_type = authdict['type']
         checker = self.checkers.get(login_type)
         if checker is not None:
-            res = yield checker(authdict, clientip)
+            # XXX: Temporary workaround for having Synapse handle password resets
+            # See AuthHandler.check_auth for further details
+            res = yield checker(
+                authdict,
+                clientip=clientip,
+                password_servlet=password_servlet,
+            )
             defer.returnValue(res)
 
         # build a v1-login-style dict out of the authdict and fall back to the
@@ -383,7 +401,7 @@ class AuthHandler(BaseHandler):
         defer.returnValue(canonical_id)
 
     @defer.inlineCallbacks
-    def _check_recaptcha(self, authdict, clientip):
+    def _check_recaptcha(self, authdict, clientip, **kwargs):
         try:
             user_response = authdict["response"]
         except KeyError:
@@ -429,20 +447,20 @@ class AuthHandler(BaseHandler):
                 defer.returnValue(True)
         raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
 
-    def _check_email_identity(self, authdict, _):
-        return self._check_threepid('email', authdict)
+    def _check_email_identity(self, authdict, **kwargs):
+        return self._check_threepid('email', authdict, **kwargs)
 
-    def _check_msisdn(self, authdict, _):
+    def _check_msisdn(self, authdict, **kwargs):
         return self._check_threepid('msisdn', authdict)
 
-    def _check_dummy_auth(self, authdict, _):
+    def _check_dummy_auth(self, authdict, **kwargs):
         return defer.succeed(True)
 
-    def _check_terms_auth(self, authdict, _):
+    def _check_terms_auth(self, authdict, **kwargs):
         return defer.succeed(True)
 
     @defer.inlineCallbacks
-    def _check_threepid(self, medium, authdict):
+    def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs):
         if 'threepid_creds' not in authdict:
             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
 
@@ -451,7 +469,29 @@ class AuthHandler(BaseHandler):
         identity_handler = self.hs.get_handlers().identity_handler
 
         logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
-        threepid = yield identity_handler.threepid_from_creds(threepid_creds)
+        if (
+            not password_servlet
+            or self.hs.config.email_password_reset_behaviour == "remote"
+        ):
+            threepid = yield identity_handler.threepid_from_creds(threepid_creds)
+        elif self.hs.config.email_password_reset_behaviour == "local":
+            row = yield self.store.get_threepid_validation_session(
+                medium,
+                threepid_creds["client_secret"],
+                sid=threepid_creds["sid"],
+            )
+
+            threepid = {
+                "medium": row["medium"],
+                "address": row["address"],
+                "validated_at": row["validated_at"],
+            } if row else None
+
+            if row:
+                # Valid threepid returned, delete from the db
+                yield self.store.delete_threepid_session(threepid_creds["sid"])
+        else:
+            raise SynapseError(400, "Password resets are not enabled on this homeserver")
 
         if not threepid:
             raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)