summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py37
1 files changed, 22 insertions, 15 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 8619fbb982..ff103cbb92 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,10 +18,20 @@ import logging
 import time
 import unicodedata
 import urllib.parse
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import attr
-import bcrypt  # type: ignore[import]
+import bcrypt
 import pymacaroons
 
 from synapse.api.constants import LoginType
@@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
@@ -470,9 +479,7 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
-        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
-            0
-        ].decode("ascii", "surrogateescape")
+        user_agent = request.get_user_agent("")
 
         await self.store.add_user_agent_ip_to_ui_auth_session(
             session.session_id, user_agent, clientip
@@ -692,7 +699,7 @@ class AuthHandler(BaseHandler):
         Creates a new access token for the user with the given user ID.
 
         The user is assumed to have been authenticated by some other
-        machanism (e.g. CAS), and the user_id converted to the canonical case.
+        mechanism (e.g. CAS), and the user_id converted to the canonical case.
 
         The device will be recorded in the table if it is not there already.
 
@@ -984,17 +991,17 @@ class AuthHandler(BaseHandler):
                 # This might return an awaitable, if it does block the log out
                 # until it completes.
                 result = provider.on_logged_out(
-                    user_id=str(user_info["user"]),
-                    device_id=user_info["device_id"],
+                    user_id=user_info.user_id,
+                    device_id=user_info.device_id,
                     access_token=access_token,
                 )
                 if inspect.isawaitable(result):
                     await result
 
         # delete pushers associated with this access token
-        if user_info["token_id"] is not None:
+        if user_info.token_id is not None:
             await self.hs.get_pusherpool().remove_pushers_by_access_token(
-                str(user_info["user"]), (user_info["token_id"],)
+                user_info.user_id, (user_info.token_id,)
             )
 
     async def delete_access_tokens_for_user(