summary refs log tree commit diff
path: root/synapse/handlers/oidc_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r--synapse/handlers/oidc_handler.py32
1 files changed, 22 insertions, 10 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index fa5ee5de8f..c5bd2fea68 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
 from urllib.parse import urlencode
@@ -38,8 +37,8 @@ from synapse.config import ConfigError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.push.mailer import load_jinja2_templates
 from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -94,6 +93,7 @@ class OidcHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
+        self.hs = hs
         self._callback_url = hs.config.oidc_callback_url  # type: str
         self._scopes = hs.config.oidc_scopes  # type: List[str]
         self._client_auth = ClientAuth(
@@ -123,9 +123,7 @@ class OidcHandler:
         self._hostname = hs.hostname  # type: str
         self._server_name = hs.config.server_name  # type: str
         self._macaroon_secret_key = hs.config.macaroon_secret_key
-        self._error_template = load_jinja2_templates(
-            hs.config.sso_template_dir, ["sso_error.html"]
-        )[0]
+        self._error_template = hs.config.sso_error_template
 
         # identifier for the external_ids table
         self._auth_provider_id = "oidc"
@@ -370,7 +368,7 @@ class OidcHandler:
             # and check for an error field. If not, we respond with a generic
             # error message.
             try:
-                resp = json.loads(resp_body.decode("utf-8"))
+                resp = json_decoder.decode(resp_body.decode("utf-8"))
                 error = resp["error"]
                 description = resp.get("error_description", error)
             except (ValueError, KeyError):
@@ -387,7 +385,7 @@ class OidcHandler:
 
         # Since it is a not a 5xx code, body should be a valid JSON. It will
         # raise if not.
-        resp = json.loads(resp_body.decode("utf-8"))
+        resp = json_decoder.decode(resp_body.decode("utf-8"))
 
         if "error" in resp:
             error = resp["error"]
@@ -692,9 +690,17 @@ class OidcHandler:
                 self._render_error(request, "invalid_token", str(e))
                 return
 
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+        ip_address = self.hs.get_ip_from_request(request)
+
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_userinfo_to_user(userinfo, token)
+            user_id = await self._map_userinfo_to_user(
+                userinfo, token, user_agent, ip_address
+            )
         except MappingException as e:
             logger.exception("Could not map user")
             self._render_error(request, "mapping_error", str(e))
@@ -831,7 +837,9 @@ class OidcHandler:
         now = self._clock.time_msec()
         return now < expiry
 
-    async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+    async def _map_userinfo_to_user(
+        self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
+    ) -> str:
         """Maps a UserInfo object to a mxid.
 
         UserInfo should have a claim that uniquely identifies users. This claim
@@ -846,6 +854,8 @@ class OidcHandler:
         Args:
             userinfo: an object representing the user
             token: a dict with the tokens obtained from the provider
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
 
         Raises:
             MappingException: if there was an error while mapping some properties
@@ -902,7 +912,9 @@ class OidcHandler:
         # It's the first time this user is logging in and the mapped mxid was
         # not taken, register the user
         registered_user_id = await self._registration_handler.register_user(
-            localpart=localpart, default_display_name=attributes["display_name"],
+            localpart=localpart,
+            default_display_name=attributes["display_name"],
+            user_agent_ips=(user_agent, ip_address),
         )
 
         await self._datastore.record_user_external_id(