diff options
Diffstat (limited to 'synapse/api/auth.py')
-rw-r--r-- | synapse/api/auth.py | 113 |
1 files changed, 46 insertions, 67 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 526cb58c5f..bfcaf68b2a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging import opentracing as opentracing +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import StateMap, UserID from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -190,10 +191,6 @@ class Auth: user_id, app_service = await self._get_appservice_user_id(request) if user_id: - request.authenticated_entity = user_id - opentracing.set_tag("authenticated_entity", user_id) - opentracing.set_tag("appservice_id", app_service.id) - if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( user_id=user_id, @@ -203,31 +200,38 @@ class Auth: device_id="dummy-device", # stubbed ) - return synapse.types.create_requester(user_id, app_service=app_service) + requester = synapse.types.create_requester( + user_id, app_service=app_service + ) + + request.requester = user_id + opentracing.set_tag("authenticated_entity", user_id) + opentracing.set_tag("user_id", user_id) + opentracing.set_tag("appservice_id", app_service.id) + + return requester user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) - user = user_info["user"] - token_id = user_info["token_id"] - is_guest = user_info["is_guest"] - shadow_banned = user_info["shadow_banned"] + token_id = user_info.token_id + is_guest = user_info.is_guest + shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: - user_id = user.to_string() - if await self.store.is_account_expired(user_id, self.clock.time_msec()): + if await self.store.is_account_expired( + user_info.user_id, self.clock.time_msec() + ): raise AuthError( 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) - # device_id may not be present if get_user_by_access_token has been - # stubbed out. - device_id = user_info.get("device_id") + device_id = user_info.device_id - if user and access_token and ip_addr: + if access_token and ip_addr: await self.store.insert_client_ip( - user_id=user.to_string(), + user_id=user_info.token_owner, access_token=access_token, ip=ip_addr, user_agent=user_agent, @@ -241,19 +245,23 @@ class Auth: errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) - request.authenticated_entity = user.to_string() - opentracing.set_tag("authenticated_entity", user.to_string()) - if device_id: - opentracing.set_tag("device_id", device_id) - - return synapse.types.create_requester( - user, + requester = synapse.types.create_requester( + user_info.user_id, token_id, is_guest, shadow_banned, device_id, app_service=app_service, + authenticated_entity=user_info.token_owner, ) + + request.requester = requester + opentracing.set_tag("authenticated_entity", user_info.token_owner) + opentracing.set_tag("user_id", user_info.user_id) + if device_id: + opentracing.set_tag("device_id", device_id) + + return requester except KeyError: raise MissingClientTokenError() @@ -284,7 +292,7 @@ class Auth: async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ) -> dict: + ) -> TokenLookupResult: """ Validate access token and get user_id from it Args: @@ -293,13 +301,7 @@ class Auth: allow this allow_expired: If False, raises an InvalidClientTokenError if the token is expired - Returns: - dict that includes: - `user` (UserID) - `is_guest` (bool) - `shadow_banned` (bool) - `token_id` (int|None): access token id. May be None if guest - `device_id` (str|None): device corresponding to access token + Raises: InvalidClientTokenError if a user by that token exists, but the token is expired @@ -309,9 +311,9 @@ class Auth: if rights == "access": # first look in the database - r = await self._look_up_user_by_access_token(token) + r = await self.store.get_user_by_access_token(token) if r: - valid_until_ms = r["valid_until_ms"] + valid_until_ms = r.valid_until_ms if ( not allow_expired and valid_until_ms is not None @@ -328,7 +330,6 @@ class Auth: # otherwise it needs to be a valid macaroon try: user_id, guest = self._parse_and_validate_macaroon(token, rights) - user = UserID.from_string(user_id) if rights == "access": if not guest: @@ -354,23 +355,17 @@ class Auth: raise InvalidClientTokenError( "Guest access token used for regular user" ) - ret = { - "user": user, - "is_guest": True, - "shadow_banned": False, - "token_id": None, + + ret = TokenLookupResult( + user_id=user_id, + is_guest=True, # all guests get the same device id - "device_id": GUEST_DEVICE_ID, - } + device_id=GUEST_DEVICE_ID, + ) elif rights == "delete_pusher": # We don't store these tokens in the database - ret = { - "user": user, - "is_guest": False, - "shadow_banned": False, - "token_id": None, - "device_id": None, - } + + ret = TokenLookupResult(user_id=user_id, is_guest=False) else: raise RuntimeError("Unknown rights setting %s", rights) return ret @@ -479,31 +474,15 @@ class Auth: now = self.hs.get_clock().time_msec() return now < expiry - async def _look_up_user_by_access_token(self, token): - ret = await self.store.get_user_by_access_token(token) - if not ret: - return None - - # we use ret.get() below because *lots* of unit tests stub out - # get_user_by_access_token in a way where it only returns a couple of - # the fields. - user_info = { - "user": UserID.from_string(ret.get("name")), - "token_id": ret.get("token_id", None), - "is_guest": False, - "shadow_banned": ret.get("shadow_banned"), - "device_id": ret.get("device_id"), - "valid_until_ms": ret.get("valid_until_ms"), - } - return user_info - def get_appservice_by_req(self, request): token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() - request.authenticated_entity = service.sender + request.requester = synapse.types.create_requester( + service.sender, app_service=service + ) return service async def is_server_admin(self, user: UserID) -> bool: |