summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7935.misc1
-rw-r--r--docs/password_auth_providers.md187
-rw-r--r--synapse/handlers/auth.py7
-rw-r--r--synapse/handlers/ui_auth/checkers.py35
4 files changed, 118 insertions, 112 deletions
diff --git a/changelog.d/7935.misc b/changelog.d/7935.misc
new file mode 100644
index 0000000000..3771f99bf2
--- /dev/null
+++ b/changelog.d/7935.misc
@@ -0,0 +1 @@
+Convert the auth providers to be async/await.
diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md
index 5d9ae67041..fef1d47e85 100644
--- a/docs/password_auth_providers.md
+++ b/docs/password_auth_providers.md
@@ -19,102 +19,103 @@ password auth provider module implementations:
 
 Password auth provider classes must provide the following methods:
 
-*class* `SomeProvider.parse_config`(*config*)
+* `parse_config(config)`
+  This method is passed the `config` object for this module from the
+  homeserver configuration file.
 
-> This method is passed the `config` object for this module from the
-> homeserver configuration file.
->
-> It should perform any appropriate sanity checks on the provided
-> configuration, and return an object which is then passed into
-> `__init__`.
+  It should perform any appropriate sanity checks on the provided
+  configuration, and return an object which is then passed into
 
-*class* `SomeProvider`(*config*, *account_handler*)
+  This method should have the `@staticmethod` decoration.
 
-> The constructor is passed the config object returned by
-> `parse_config`, and a `synapse.module_api.ModuleApi` object which
-> allows the password provider to check if accounts exist and/or create
-> new ones.
+* `__init__(self, config, account_handler)`
+
+  The constructor is passed the config object returned by
+  `parse_config`, and a `synapse.module_api.ModuleApi` object which
+  allows the password provider to check if accounts exist and/or create
+  new ones.
 
 ## Optional methods
 
-Password auth provider classes may optionally provide the following
-methods.
-
-*class* `SomeProvider.get_db_schema_files`()
-
-> This method, if implemented, should return an Iterable of
-> `(name, stream)` pairs of database schema files. Each file is applied
-> in turn at initialisation, and a record is then made in the database
-> so that it is not re-applied on the next start.
-
-`someprovider.get_supported_login_types`()
-
-> This method, if implemented, should return a `dict` mapping from a
-> login type identifier (such as `m.login.password`) to an iterable
-> giving the fields which must be provided by the user in the submission
-> to the `/login` api. These fields are passed in the `login_dict`
-> dictionary to `check_auth`.
->
-> For example, if a password auth provider wants to implement a custom
-> login type of `com.example.custom_login`, where the client is expected
-> to pass the fields `secret1` and `secret2`, the provider should
-> implement this method and return the following dict:
->
->     {"com.example.custom_login": ("secret1", "secret2")}
-
-`someprovider.check_auth`(*username*, *login_type*, *login_dict*)
-
-> This method is the one that does the real work. If implemented, it
-> will be called for each login attempt where the login type matches one
-> of the keys returned by `get_supported_login_types`.
->
-> It is passed the (possibly UNqualified) `user` provided by the client,
-> the login type, and a dictionary of login secrets passed by the
-> client.
->
-> The method should return a Twisted `Deferred` object, which resolves
-> to the canonical `@localpart:domain` user id if authentication is
-> successful, and `None` if not.
->
-> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in
-> which case the second field is a callback which will be called with
-> the result from the `/login` call (including `access_token`,
-> `device_id`, etc.)
-
-`someprovider.check_3pid_auth`(*medium*, *address*, *password*)
-
-> This method, if implemented, is called when a user attempts to
-> register or log in with a third party identifier, such as email. It is
-> passed the medium (ex. "email"), an address (ex.
-> "<jdoe@example.com>") and the user's password.
->
-> The method should return a Twisted `Deferred` object, which resolves
-> to a `str` containing the user's (canonical) User ID if
-> authentication was successful, and `None` if not.
->
-> As with `check_auth`, the `Deferred` may alternatively resolve to a
-> `(user_id, callback)` tuple.
-
-`someprovider.check_password`(*user_id*, *password*)
-
-> This method provides a simpler interface than
-> `get_supported_login_types` and `check_auth` for password auth
-> providers that just want to provide a mechanism for validating
-> `m.login.password` logins.
->
-> Iif implemented, it will be called to check logins with an
-> `m.login.password` login type. It is passed a qualified
-> `@localpart:domain` user id, and the password provided by the user.
->
-> The method should return a Twisted `Deferred` object, which resolves
-> to `True` if authentication is successful, and `False` if not.
-
-`someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*)
-
-> This method, if implemented, is called when a user logs out. It is
-> passed the qualified user ID, the ID of the deactivated device (if
-> any: access tokens are occasionally created without an associated
-> device ID), and the (now deactivated) access token.
->
-> It may return a Twisted `Deferred` object; the logout request will
-> wait for the deferred to complete but the result is ignored.
+Password auth provider classes may optionally provide the following methods:
+
+* `get_db_schema_files(self)`
+
+  This method, if implemented, should return an Iterable of
+  `(name, stream)` pairs of database schema files. Each file is applied
+  in turn at initialisation, and a record is then made in the database
+  so that it is not re-applied on the next start.
+
+* `get_supported_login_types(self)`
+
+  This method, if implemented, should return a `dict` mapping from a
+  login type identifier (such as `m.login.password`) to an iterable
+  giving the fields which must be provided by the user in the submission
+  to [the `/login` API](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login).
+  These fields are passed in the `login_dict` dictionary to `check_auth`.
+
+  For example, if a password auth provider wants to implement a custom
+  login type of `com.example.custom_login`, where the client is expected
+  to pass the fields `secret1` and `secret2`, the provider should
+  implement this method and return the following dict:
+
+  ```python
+  {"com.example.custom_login": ("secret1", "secret2")}
+  ```
+
+* `check_auth(self, username, login_type, login_dict)`
+
+  This method does the real work. If implemented, it
+  will be called for each login attempt where the login type matches one
+  of the keys returned by `get_supported_login_types`.
+
+  It is passed the (possibly unqualified) `user` field provided by the client,
+  the login type, and a dictionary of login secrets passed by the
+  client.
+
+  The method should return an `Awaitable` object, which resolves
+  to the canonical `@localpart:domain` user ID if authentication is
+  successful, and `None` if not.
+
+  Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in
+  which case the second field is a callback which will be called with
+  the result from the `/login` call (including `access_token`,
+  `device_id`, etc.)
+
+* `check_3pid_auth(self, medium, address, password)`
+
+  This method, if implemented, is called when a user attempts to
+  register or log in with a third party identifier, such as email. It is
+  passed the medium (ex. "email"), an address (ex.
+  "<jdoe@example.com>") and the user's password.
+
+  The method should return an `Awaitable` object, which resolves
+  to a `str` containing the user's (canonical) User id if
+  authentication was successful, and `None` if not.
+
+  As with `check_auth`, the `Awaitable` may alternatively resolve to a
+  `(user_id, callback)` tuple.
+
+* `check_password(self, user_id, password)`
+
+  This method provides a simpler interface than
+  `get_supported_login_types` and `check_auth` for password auth
+  providers that just want to provide a mechanism for validating
+  `m.login.password` logins.
+
+  If implemented, it will be called to check logins with an
+  `m.login.password` login type. It is passed a qualified
+  `@localpart:domain` user id, and the password provided by the user.
+
+  The method should return an `Awaitable` object, which resolves
+  to `True` if authentication is successful, and `False` if not.
+
+* `on_logged_out(self, user_id, device_id, access_token)`
+
+  This method, if implemented, is called when a user logs out. It is
+  passed the qualified user ID, the ID of the deactivated device (if
+  any: access tokens are occasionally created without an associated
+  device ID), and the (now deactivated) access token.
+
+  It may return an `Awaitable` object; the logout request will
+  wait for the `Awaitable` to complete, but the result is ignored.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a162392e4c..c7d921c21a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 import time
 import unicodedata
@@ -863,11 +864,15 @@ class AuthHandler(BaseHandler):
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
             if hasattr(provider, "on_logged_out"):
-                await provider.on_logged_out(
+                # This might return an awaitable, if it does block the log out
+                # until it completes.
+                result = provider.on_logged_out(
                     user_id=str(user_info["user"]),
                     device_id=user_info["device_id"],
                     access_token=access_token,
                 )
+                if inspect.isawaitable(result):
+                    await result
 
         # delete pushers associated with this access token
         if user_info["token_id"] is not None:
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index a140e9391e..a011e9fe29 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 
 import logging
+from typing import Any
 
 from canonicaljson import json
 
-from twisted.internet import defer
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
@@ -33,25 +33,25 @@ class UserInteractiveAuthChecker:
     def __init__(self, hs):
         pass
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         """Check if the configuration of the homeserver allows this checker to work
 
         Returns:
-            bool: True if this login type is enabled.
+            True if this login type is enabled.
         """
 
-    def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         """Given the authentication dict from the client, attempt to check this step
 
         Args:
-            authdict (dict): authentication dictionary from the client
-            clientip (str): The IP address of the client.
+            authdict: authentication dictionary from the client
+            clientip: The IP address of the client.
 
         Raises:
             SynapseError if authentication failed
 
         Returns:
-            Deferred: the result of authentication (to pass back to the client?)
+            The result of authentication (to pass back to the client?)
         """
         raise NotImplementedError()
 
@@ -62,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
     def is_enabled(self):
         return True
 
-    def check_auth(self, authdict, clientip):
-        return defer.succeed(True)
+    async def check_auth(self, authdict, clientip):
+        return True
 
 
 class TermsAuthChecker(UserInteractiveAuthChecker):
@@ -72,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
     def is_enabled(self):
         return True
 
-    def check_auth(self, authdict, clientip):
-        return defer.succeed(True)
+    async def check_auth(self, authdict, clientip):
+        return True
 
 
 class RecaptchaAuthChecker(UserInteractiveAuthChecker):
@@ -89,8 +89,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
     def is_enabled(self):
         return self._enabled
 
-    @defer.inlineCallbacks
-    def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict, clientip):
         try:
             user_response = authdict["response"]
         except KeyError:
@@ -107,7 +106,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
         # TODO: get this from the homeserver rather than creating a new one for
         # each request
         try:
-            resp_body = yield self._http_client.post_urlencoded_get_json(
+            resp_body = await self._http_client.post_urlencoded_get_json(
                 self._url,
                 args={
                     "secret": self._secret,
@@ -219,8 +218,8 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
             ThreepidBehaviour.LOCAL,
         )
 
-    def check_auth(self, authdict, clientip):
-        return defer.ensureDeferred(self._check_threepid("email", authdict))
+    async def check_auth(self, authdict, clientip):
+        return await self._check_threepid("email", authdict)
 
 
 class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@@ -233,8 +232,8 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
     def is_enabled(self):
         return bool(self.hs.config.account_threepid_delegate_msisdn)
 
-    def check_auth(self, authdict, clientip):
-        return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
+    async def check_auth(self, authdict, clientip):
+        return await self._check_threepid("msisdn", authdict)
 
 
 INTERACTIVE_AUTH_CHECKERS = [