diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index bff87fabde..a182ce22db 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging import opentracing as opentracing
+from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -210,10 +211,10 @@ class Auth:
user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
- user = user_info["user"]
- token_id = user_info["token_id"]
- is_guest = user_info["is_guest"]
- shadow_banned = user_info["shadow_banned"]
+ user = UserID.from_string(user_info.user_id)
+ token_id = user_info.token_id
+ is_guest = user_info.is_guest
+ shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
@@ -223,11 +224,9 @@ class Auth:
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
- # device_id may not be present if get_user_by_access_token has been
- # stubbed out.
- device_id = user_info.get("device_id")
+ device_id = user_info.device_id
- if user and access_token and ip_addr:
+ if access_token and ip_addr:
await self.store.insert_client_ip(
user_id=user.to_string(),
access_token=access_token,
@@ -286,7 +285,7 @@ class Auth:
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
- ) -> dict:
+ ) -> TokenLookupResult:
""" Validate access token and get user_id from it
Args:
@@ -295,13 +294,7 @@ class Auth:
allow this
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
- Returns:
- dict that includes:
- `user` (UserID)
- `is_guest` (bool)
- `shadow_banned` (bool)
- `token_id` (int|None): access token id. May be None if guest
- `device_id` (str|None): device corresponding to access token
+
Raises:
InvalidClientTokenError if a user by that token exists, but the token is
expired
@@ -311,9 +304,9 @@ class Auth:
if rights == "access":
# first look in the database
- r = await self._look_up_user_by_access_token(token)
+ r = await self.store.get_user_by_access_token(token)
if r:
- valid_until_ms = r["valid_until_ms"]
+ valid_until_ms = r.valid_until_ms
if (
not allow_expired
and valid_until_ms is not None
@@ -330,7 +323,6 @@ class Auth:
# otherwise it needs to be a valid macaroon
try:
user_id, guest = self._parse_and_validate_macaroon(token, rights)
- user = UserID.from_string(user_id)
if rights == "access":
if not guest:
@@ -356,23 +348,17 @@ class Auth:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)
- ret = {
- "user": user,
- "is_guest": True,
- "shadow_banned": False,
- "token_id": None,
+
+ ret = TokenLookupResult(
+ user_id=user_id,
+ is_guest=True,
# all guests get the same device id
- "device_id": GUEST_DEVICE_ID,
- }
+ device_id=GUEST_DEVICE_ID,
+ )
elif rights == "delete_pusher":
# We don't store these tokens in the database
- ret = {
- "user": user,
- "is_guest": False,
- "shadow_banned": False,
- "token_id": None,
- "device_id": None,
- }
+
+ ret = TokenLookupResult(user_id=user_id, is_guest=False)
else:
raise RuntimeError("Unknown rights setting %s", rights)
return ret
@@ -481,24 +467,6 @@ class Auth:
now = self.hs.get_clock().time_msec()
return now < expiry
- async def _look_up_user_by_access_token(self, token):
- ret = await self.store.get_user_by_access_token(token)
- if not ret:
- return None
-
- # we use ret.get() below because *lots* of unit tests stub out
- # get_user_by_access_token in a way where it only returns a couple of
- # the fields.
- user_info = {
- "user": UserID.from_string(ret.get("name")),
- "token_id": ret.get("token_id", None),
- "is_guest": False,
- "shadow_banned": ret.get("shadow_banned"),
- "device_id": ret.get("device_id"),
- "valid_until_ms": ret.get("valid_until_ms"),
- }
- return user_info
-
def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
|