diff options
author | Richard van der Hoff <1389908+richvdh@users.noreply.github.com> | 2021-01-27 21:28:59 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-01-27 21:28:59 +0000 |
commit | 869667760f571c9edebab660061e17035d57f182 (patch) | |
tree | 0ea90b7f0a234fb47fcafe130f21fadd780124e4 /synapse/handlers/oidc_handler.py | |
parent | Merge tag 'v1.26.0' into social_login (diff) | |
download | synapse-869667760f571c9edebab660061e17035d57f182.tar.xz |
Support for scraping email addresses from OIDC providers (#9245)
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r-- | synapse/handlers/oidc_handler.py | 52 |
1 files changed, 28 insertions, 24 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 1607e12935..324ddb798c 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -1056,7 +1056,8 @@ class OidcSessionData: UserAttributeDict = TypedDict( - "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} + "UserAttributeDict", + {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]}, ) C = TypeVar("C") @@ -1135,11 +1136,12 @@ def jinja_finalize(thing): env = Environment(finalize=jinja_finalize) -@attr.s +@attr.s(slots=True, frozen=True) class JinjaOidcMappingConfig: subject_claim = attr.ib(type=str) localpart_template = attr.ib(type=Optional[Template]) display_name_template = attr.ib(type=Optional[Template]) + email_template = attr.ib(type=Optional[Template]) extra_attributes = attr.ib(type=Dict[str, Template]) @@ -1156,23 +1158,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") - localpart_template = None # type: Optional[Template] - if "localpart_template" in config: + def parse_template_config(option_name: str) -> Optional[Template]: + if option_name not in config: + return None try: - localpart_template = env.from_string(config["localpart_template"]) + return env.from_string(config[option_name]) except Exception as e: - raise ConfigError( - "invalid jinja template", path=["localpart_template"] - ) from e + raise ConfigError("invalid jinja template", path=[option_name]) from e - display_name_template = None # type: Optional[Template] - if "display_name_template" in config: - try: - display_name_template = env.from_string(config["display_name_template"]) - except Exception as e: - raise ConfigError( - "invalid jinja template", path=["display_name_template"] - ) from e + localpart_template = parse_template_config("localpart_template") + display_name_template = parse_template_config("display_name_template") + email_template = parse_template_config("email_template") extra_attributes = {} # type Dict[str, Template] if "extra_attributes" in config: @@ -1192,6 +1188,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): subject_claim=subject_claim, localpart_template=localpart_template, display_name_template=display_name_template, + email_template=email_template, extra_attributes=extra_attributes, ) @@ -1213,16 +1210,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): # a usable mxid. localpart += str(failures) if failures else "" - display_name = None # type: Optional[str] - if self._config.display_name_template is not None: - display_name = self._config.display_name_template.render( - user=userinfo - ).strip() + def render_template_field(template: Optional[Template]) -> Optional[str]: + if template is None: + return None + return template.render(user=userinfo).strip() + + display_name = render_template_field(self._config.display_name_template) + if display_name == "": + display_name = None - if display_name == "": - display_name = None + emails = [] # type: List[str] + email = render_template_field(self._config.email_template) + if email: + emails.append(email) - return UserAttributeDict(localpart=localpart, display_name=display_name) + return UserAttributeDict( + localpart=localpart, display_name=display_name, emails=emails + ) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: extras = {} # type: Dict[str, str] |