diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 8815f685b9..c7dc07008a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +26,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
# better way to break the loop
account_handler = ModuleApi(hs, self)
- self.password_providers = []
- for module, config in hs.config.password_providers:
- try:
- self.password_providers.append(
- module(config=config, account_handler=account_handler)
- )
- except Exception as e:
- logger.error("Error while initializing %r: %s", module, e)
- raise
+ self.password_providers = [
+ PasswordProvider.load(module, config, account_handler)
+ for module, config in hs.config.password_providers
+ ]
- logger.info("Extra password_providers: %r", self.password_providers)
+ logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
@@ -853,6 +850,8 @@ class AuthHandler(BaseHandler):
LoginError if there was an authentication problem.
"""
login_type = login_submission.get("type")
+ if not isinstance(login_type, str):
+ raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
# ideally, we wouldn't be checking the identifier unless we know we have a login
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
@@ -998,24 +997,12 @@ class AuthHandler(BaseHandler):
qualified_user_id = UserID(username, self.hs.hostname).to_string()
login_type = login_submission.get("type")
+ # we already checked that we have a valid login type
+ assert isinstance(login_type, str)
+
known_login_type = False
for provider in self.password_providers:
- if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
- known_login_type = True
- # we've already checked that there is a (valid) password field
- is_valid = await provider.check_password(
- qualified_user_id, login_submission["password"]
- )
- if is_valid:
- return qualified_user_id, None
-
- if not hasattr(provider, "get_supported_login_types") or not hasattr(
- provider, "check_auth"
- ):
- # this password provider doesn't understand custom login types
- continue
-
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
@@ -1040,8 +1027,6 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict)
if result:
- if isinstance(result, str):
- result = (result, None)
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
@@ -1083,19 +1068,9 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
- if hasattr(provider, "check_3pid_auth"):
- # This function is able to return a deferred that either
- # resolves None, meaning authentication failure, or upon
- # success, to a str (which is the user_id) or a tuple of
- # (user_id, callback_func), where callback_func should be run
- # after we've finished everything else
- result = await provider.check_3pid_auth(medium, address, password)
- if result:
- # Check if the return value is a str or a tuple
- if isinstance(result, str):
- # If it's a str, set callback function to None
- result = (result, None)
- return result
+ result = await provider.check_3pid_auth(medium, address, password)
+ if result:
+ return result
return None, None
@@ -1153,16 +1128,11 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- # This might return an awaitable, if it does block the log out
- # until it completes.
- result = provider.on_logged_out(
- user_id=user_info.user_id,
- device_id=user_info.device_id,
- access_token=access_token,
- )
- if inspect.isawaitable(result):
- await result
+ await provider.on_logged_out(
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
+ access_token=access_token,
+ )
# delete pushers associated with this access token
if user_info.token_id is not None:
@@ -1191,11 +1161,10 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- for token, token_id, device_id in tokens_and_devices:
- await provider.on_logged_out(
- user_id=user_id, device_id=device_id, access_token=token
- )
+ for token, token_id, device_id in tokens_and_devices:
+ await provider.on_logged_out(
+ user_id=user_id, device_id=device_id, access_token=token
+ )
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1519,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
+
+
+class PasswordProvider:
+ """Wrapper for a password auth provider module
+
+ This class abstracts out all of the backwards-compatibility hacks for
+ password providers, to provide a consistent interface.
+ """
+
+ @classmethod
+ def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+ try:
+ pp = module(config=config, account_handler=module_api)
+ except Exception as e:
+ logger.error("Error while initializing %r: %s", module, e)
+ raise
+ return cls(pp, module_api)
+
+ def __init__(self, pp, module_api: ModuleApi):
+ self._pp = pp
+ self._module_api = module_api
+
+ self._supported_login_types = {}
+
+ # grandfather in check_password support
+ if hasattr(self._pp, "check_password"):
+ self._supported_login_types[LoginType.PASSWORD] = ("password",)
+
+ g = getattr(self._pp, "get_supported_login_types", None)
+ if g:
+ self._supported_login_types.update(g())
+
+ def __str__(self):
+ return str(self._pp)
+
+ def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
+ """Get the login types supported by this password provider
+
+ Returns a map from a login type identifier (such as m.login.password) to an
+ iterable giving the fields which must be provided by the user in the submission
+ to the /login API.
+
+ This wrapper adds m.login.password to the list if the underlying password
+ provider supports the check_password() api.
+ """
+ return self._supported_login_types
+
+ async def check_auth(
+ self, username: str, login_type: str, login_dict: JsonDict
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ """Check if the user has presented valid login credentials
+
+ This wrapper also calls check_password() if the underlying password provider
+ supports the check_password() api and the login type is m.login.password.
+
+ Args:
+ username: user id presented by the client. Either an MXID or an unqualified
+ username.
+
+ login_type: the login type being attempted - one of the types returned by
+ get_supported_login_types()
+
+ login_dict: the dictionary of login secrets passed by the client.
+
+ Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
+ user, and `callback` is an optional callback which will be called with the
+ result from the /login call (including access_token, device_id, etc.)
+ """
+ # first grandfather in a call to check_password
+ if login_type == LoginType.PASSWORD:
+ g = getattr(self._pp, "check_password", None)
+ if g:
+ qualified_user_id = self._module_api.get_qualified_user_id(username)
+ is_valid = await self._pp.check_password(
+ qualified_user_id, login_dict["password"]
+ )
+ if is_valid:
+ return qualified_user_id, None
+
+ g = getattr(self._pp, "check_auth", None)
+ if not g:
+ return None
+ result = await g(username, login_type, login_dict)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def check_3pid_auth(
+ self, medium: str, address: str, password: str
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ g = getattr(self._pp, "check_3pid_auth", None)
+ if not g:
+ return None
+
+ # This function is able to return a deferred that either
+ # resolves None, meaning authentication failure, or upon
+ # success, to a str (which is the user_id) or a tuple of
+ # (user_id, callback_func), where callback_func should be run
+ # after we've finished everything else
+ result = await g(medium, address, password)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def on_logged_out(
+ self, user_id: str, device_id: Optional[str], access_token: str
+ ) -> None:
+ g = getattr(self._pp, "on_logged_out", None)
+ if not g:
+ return
+
+ # This might return an awaitable, if it does block the log out
+ # until it completes.
+ result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ if inspect.isawaitable(result):
+ await result
|