summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2021-01-28 22:08:11 +0000
committerRichard van der Hoff <richard@matrix.org>2021-01-28 22:08:11 +0000
commit0d81a6fa3e1dc832f56ed09805229b9089758ba5 (patch)
tree08a29e6210ef63c506f8a50944163cdabd400a8b /synapse/handlers
parentRatelimit 3PID /requestToken API (#9238) (diff)
parentAdd 'brand' field to MSC2858 response (#9242) (diff)
downloadsynapse-0d81a6fa3e1dc832f56ed09805229b9089758ba5.tar.xz
Merge branch 'social_login' into develop
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/cas_handler.py3
-rw-r--r--synapse/handlers/oidc_handler.py55
-rw-r--r--synapse/handlers/saml_handler.py3
-rw-r--r--synapse/handlers/sso.py5
4 files changed, 40 insertions, 26 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 21b6bc4992..bd35d1fb87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -80,9 +80,10 @@ class CasHandler:
         # user-facing name of this auth provider
         self.idp_name = "CAS"
 
-        # we do not currently support icons for CAS auth, but this is required by
+        # we do not currently support brands/icons for CAS auth, but this is required by
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
+        self.idp_brand = None
 
         self._sso_handler = hs.get_sso_handler()
 
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 1607e12935..ca647fa78f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -274,6 +274,9 @@ class OidcProvider:
         # MXC URI for icon for this auth provider
         self.idp_icon = provider.idp_icon
 
+        # optional brand identifier for this auth provider
+        self.idp_brand = provider.idp_brand
+
         self._sso_handler = hs.get_sso_handler()
 
         self._sso_handler.register_identity_provider(self)
@@ -1056,7 +1059,8 @@ class OidcSessionData:
 
 
 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
+    "UserAttributeDict",
+    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
 )
 C = TypeVar("C")
 
@@ -1135,11 +1139,12 @@ def jinja_finalize(thing):
 env = Environment(finalize=jinja_finalize)
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class JinjaOidcMappingConfig:
     subject_claim = attr.ib(type=str)
     localpart_template = attr.ib(type=Optional[Template])
     display_name_template = attr.ib(type=Optional[Template])
+    email_template = attr.ib(type=Optional[Template])
     extra_attributes = attr.ib(type=Dict[str, Template])
 
 
@@ -1156,23 +1161,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
 
-        localpart_template = None  # type: Optional[Template]
-        if "localpart_template" in config:
+        def parse_template_config(option_name: str) -> Optional[Template]:
+            if option_name not in config:
+                return None
             try:
-                localpart_template = env.from_string(config["localpart_template"])
+                return env.from_string(config[option_name])
             except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template", path=["localpart_template"]
-                ) from e
+                raise ConfigError("invalid jinja template", path=[option_name]) from e
 
-        display_name_template = None  # type: Optional[Template]
-        if "display_name_template" in config:
-            try:
-                display_name_template = env.from_string(config["display_name_template"])
-            except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template", path=["display_name_template"]
-                ) from e
+        localpart_template = parse_template_config("localpart_template")
+        display_name_template = parse_template_config("display_name_template")
+        email_template = parse_template_config("email_template")
 
         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
@@ -1192,6 +1191,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
+            email_template=email_template,
             extra_attributes=extra_attributes,
         )
 
@@ -1213,16 +1213,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             # a usable mxid.
             localpart += str(failures) if failures else ""
 
-        display_name = None  # type: Optional[str]
-        if self._config.display_name_template is not None:
-            display_name = self._config.display_name_template.render(
-                user=userinfo
-            ).strip()
+        def render_template_field(template: Optional[Template]) -> Optional[str]:
+            if template is None:
+                return None
+            return template.render(user=userinfo).strip()
+
+        display_name = render_template_field(self._config.display_name_template)
+        if display_name == "":
+            display_name = None
 
-            if display_name == "":
-                display_name = None
+        emails = []  # type: List[str]
+        email = render_template_field(self._config.email_template)
+        if email:
+            emails.append(email)
 
-        return UserAttributeDict(localpart=localpart, display_name=display_name)
+        return UserAttributeDict(
+            localpart=localpart, display_name=display_name, emails=emails
+        )
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 38461cf79d..5946919c33 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -78,9 +78,10 @@ class SamlHandler(BaseHandler):
         # user-facing name of this auth provider
         self.idp_name = "SAML"
 
-        # we do not currently support icons for SAML auth, but this is required by
+        # we do not currently support icons/brands for SAML auth, but this is required by
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
+        self.idp_brand = None
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index afc1341d09..3308b037d2 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -80,6 +80,11 @@ class SsoIdentityProvider(Protocol):
         """Optional MXC URI for user-facing icon"""
         return None
 
+    @property
+    def idp_brand(self) -> Optional[str]:
+        """Optional branding identifier"""
+        return None
+
     @abc.abstractmethod
     async def handle_redirect_request(
         self,