summary refs log tree commit diff
path: root/synapse/handlers/ui_auth/checkers.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/ui_auth/checkers.py')
-rw-r--r--synapse/handlers/ui_auth/checkers.py41
1 files changed, 20 insertions, 21 deletions
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8b24a73319..9146dc1a3b 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -12,16 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 
-from canonicaljson import json
+import logging
+from typing import Any
 
-from twisted.internet import defer
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -32,25 +32,25 @@ class UserInteractiveAuthChecker:
     def __init__(self, hs):
         pass
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         """Check if the configuration of the homeserver allows this checker to work
 
         Returns:
-            bool: True if this login type is enabled.
+            True if this login type is enabled.
         """
 
-    def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         """Given the authentication dict from the client, attempt to check this step
 
         Args:
-            authdict (dict): authentication dictionary from the client
-            clientip (str): The IP address of the client.
+            authdict: authentication dictionary from the client
+            clientip: The IP address of the client.
 
         Raises:
             SynapseError if authentication failed
 
         Returns:
-            Deferred: the result of authentication (to pass back to the client?)
+            The result of authentication (to pass back to the client?)
         """
         raise NotImplementedError()
 
@@ -61,8 +61,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
     def is_enabled(self):
         return True
 
-    def check_auth(self, authdict, clientip):
-        return defer.succeed(True)
+    async def check_auth(self, authdict, clientip):
+        return True
 
 
 class TermsAuthChecker(UserInteractiveAuthChecker):
@@ -71,8 +71,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
     def is_enabled(self):
         return True
 
-    def check_auth(self, authdict, clientip):
-        return defer.succeed(True)
+    async def check_auth(self, authdict, clientip):
+        return True
 
 
 class RecaptchaAuthChecker(UserInteractiveAuthChecker):
@@ -88,8 +88,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
     def is_enabled(self):
         return self._enabled
 
-    @defer.inlineCallbacks
-    def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict, clientip):
         try:
             user_response = authdict["response"]
         except KeyError:
@@ -106,7 +105,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
         # TODO: get this from the homeserver rather than creating a new one for
         # each request
         try:
-            resp_body = yield self._http_client.post_urlencoded_get_json(
+            resp_body = await self._http_client.post_urlencoded_get_json(
                 self._url,
                 args={
                     "secret": self._secret,
@@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
         except PartialDownloadError as pde:
             # Twisted is silly
             data = pde.response
-            resp_body = json.loads(data)
+            resp_body = json_decoder.decode(data.decode("utf-8"))
 
         if "success" in resp_body:
             # Note that we do NOT check the hostname here: we explicitly
@@ -218,8 +217,8 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
             ThreepidBehaviour.LOCAL,
         )
 
-    def check_auth(self, authdict, clientip):
-        return defer.ensureDeferred(self._check_threepid("email", authdict))
+    async def check_auth(self, authdict, clientip):
+        return await self._check_threepid("email", authdict)
 
 
 class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@@ -232,8 +231,8 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
     def is_enabled(self):
         return bool(self.hs.config.account_threepid_delegate_msisdn)
 
-    def check_auth(self, authdict, clientip):
-        return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
+    async def check_auth(self, authdict, clientip):
+        return await self._check_threepid("msisdn", authdict)
 
 
 INTERACTIVE_AUTH_CHECKERS = [