summary refs log tree commit diff
path: root/synapse/api
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api')
-rw-r--r--synapse/api/auth.py202
1 files changed, 95 insertions, 107 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 523bad0c55..9a1aea083f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -37,8 +37,7 @@ from synapse.logging.opentracing import (
     start_active_span,
     trace,
 )
-from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import Requester, create_requester
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -70,14 +69,14 @@ class Auth:
     async def check_user_in_room(
         self,
         room_id: str,
-        user_id: str,
+        requester: Requester,
         allow_departed_users: bool = False,
     ) -> Tuple[str, Optional[str]]:
         """Check if the user is in the room, or was at some point.
         Args:
             room_id: The room to check.
 
-            user_id: The user to check.
+            requester: The user making the request, according to the access token.
 
             current_state: Optional map of the current state of the room.
                 If provided then that map is used to check whether they are a
@@ -94,6 +93,7 @@ class Auth:
             membership event ID of the user.
         """
 
+        user_id = requester.user.to_string()
         (
             membership,
             member_event_id,
@@ -182,96 +182,69 @@ class Auth:
 
             access_token = self.get_access_token_from_request(request)
 
-            (
-                user_id,
-                device_id,
-                app_service,
-            ) = await self._get_appservice_user_id_and_device_id(request)
-            if user_id and app_service:
-                if ip_addr and self._track_appservice_user_ips:
-                    await self.store.insert_client_ip(
-                        user_id=user_id,
-                        access_token=access_token,
-                        ip=ip_addr,
-                        user_agent=user_agent,
-                        device_id="dummy-device"
-                        if device_id is None
-                        else device_id,  # stubbed
-                    )
-
-                requester = create_requester(
-                    user_id, app_service=app_service, device_id=device_id
+            # First check if it could be a request from an appservice
+            requester = await self._get_appservice_user(request)
+            if not requester:
+                # If not, it should be from a regular user
+                requester = await self.get_user_by_access_token(
+                    access_token, allow_expired=allow_expired
                 )
 
-                request.requester = user_id
-                return requester
-
-            user_info = await self.get_user_by_access_token(
-                access_token, allow_expired=allow_expired
-            )
-            token_id = user_info.token_id
-            is_guest = user_info.is_guest
-            shadow_banned = user_info.shadow_banned
-
-            # Deny the request if the user account has expired.
-            if not allow_expired:
-                if await self._account_validity_handler.is_user_expired(
-                    user_info.user_id
-                ):
-                    # Raise the error if either an account validity module has determined
-                    # the account has expired, or the legacy account validity
-                    # implementation is enabled and determined the account has expired
-                    raise AuthError(
-                        403,
-                        "User account has expired",
-                        errcode=Codes.EXPIRED_ACCOUNT,
-                    )
-
-            device_id = user_info.device_id
-
-            if access_token and ip_addr:
+                # Deny the request if the user account has expired.
+                # This check is only done for regular users, not appservice ones.
+                if not allow_expired:
+                    if await self._account_validity_handler.is_user_expired(
+                        requester.user.to_string()
+                    ):
+                        # Raise the error if either an account validity module has determined
+                        # the account has expired, or the legacy account validity
+                        # implementation is enabled and determined the account has expired
+                        raise AuthError(
+                            403,
+                            "User account has expired",
+                            errcode=Codes.EXPIRED_ACCOUNT,
+                        )
+
+            if ip_addr and (
+                not requester.app_service or self._track_appservice_user_ips
+            ):
+                # XXX(quenting): I'm 95% confident that we could skip setting the
+                # device_id to "dummy-device" for appservices, and that the only impact
+                # would be some rows which whould not deduplicate in the 'user_ips'
+                # table during the transition
+                recorded_device_id = (
+                    "dummy-device"
+                    if requester.device_id is None and requester.app_service is not None
+                    else requester.device_id
+                )
                 await self.store.insert_client_ip(
-                    user_id=user_info.token_owner,
+                    user_id=requester.authenticated_entity,
                     access_token=access_token,
                     ip=ip_addr,
                     user_agent=user_agent,
-                    device_id=device_id,
+                    device_id=recorded_device_id,
                 )
+
                 # Track also the puppeted user client IP if enabled and the user is puppeting
                 if (
-                    user_info.user_id != user_info.token_owner
+                    requester.user.to_string() != requester.authenticated_entity
                     and self._track_puppeted_user_ips
                 ):
                     await self.store.insert_client_ip(
-                        user_id=user_info.user_id,
+                        user_id=requester.user.to_string(),
                         access_token=access_token,
                         ip=ip_addr,
                         user_agent=user_agent,
-                        device_id=device_id,
+                        device_id=requester.device_id,
                     )
 
-            if is_guest and not allow_guest:
+            if requester.is_guest and not allow_guest:
                 raise AuthError(
                     403,
                     "Guest access not allowed",
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
-            # Mark the token as used. This is used to invalidate old refresh
-            # tokens after some time.
-            if not user_info.token_used and token_id is not None:
-                await self.store.mark_access_token_as_used(token_id)
-
-            requester = create_requester(
-                user_info.user_id,
-                token_id,
-                is_guest,
-                shadow_banned,
-                device_id,
-                app_service=app_service,
-                authenticated_entity=user_info.token_owner,
-            )
-
             request.requester = requester
             return requester
         except KeyError:
@@ -308,9 +281,7 @@ class Auth:
                 403, "Application service has not registered this user (%s)" % user_id
             )
 
-    async def _get_appservice_user_id_and_device_id(
-        self, request: Request
-    ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]:
+    async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
         """
         Given a request, reads the request parameters to determine:
         - whether it's an application service that's making this request
@@ -325,15 +296,13 @@ class Auth:
              Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
 
         Returns:
-            3-tuple of
-            (user ID?, device ID?, application service?)
+            the application service `Requester` of that request
 
         Postconditions:
-        - If an application service is returned, so is a user ID
-        - A user ID is never returned without an application service
-        - A device ID is never returned without a user ID or an application service
-        - The returned application service, if present, is permitted to control the
-          returned user ID.
+        - The `app_service` field in the returned `Requester` is set
+        - The `user_id` field in the returned `Requester` is either the application
+          service sender or the controlled user set by the `user_id` URI parameter
+        - The returned application service is permitted to control the returned user ID.
         - The returned device ID, if present, has been checked to be a valid device ID
           for the returned user ID.
         """
@@ -343,12 +312,12 @@ class Auth:
             self.get_access_token_from_request(request)
         )
         if app_service is None:
-            return None, None, None
+            return None
 
         if app_service.ip_range_whitelist:
             ip_address = IPAddress(request.getClientAddress().host)
             if ip_address not in app_service.ip_range_whitelist:
-                return None, None, None
+                return None
 
         # This will always be set by the time Twisted calls us.
         assert request.args is not None
@@ -382,13 +351,15 @@ class Auth:
                     Codes.EXCLUSIVE,
                 )
 
-        return effective_user_id, effective_device_id, app_service
+        return create_requester(
+            effective_user_id, app_service=app_service, device_id=effective_device_id
+        )
 
     async def get_user_by_access_token(
         self,
         token: str,
         allow_expired: bool = False,
-    ) -> TokenLookupResult:
+    ) -> Requester:
         """Validate access token and get user_id from it
 
         Args:
@@ -405,9 +376,9 @@ class Auth:
 
         # First look in the database to see if the access token is present
         # as an opaque token.
-        r = await self.store.get_user_by_access_token(token)
-        if r:
-            valid_until_ms = r.valid_until_ms
+        user_info = await self.store.get_user_by_access_token(token)
+        if user_info:
+            valid_until_ms = user_info.valid_until_ms
             if (
                 not allow_expired
                 and valid_until_ms is not None
@@ -419,7 +390,20 @@ class Auth:
                     msg="Access token has expired", soft_logout=True
                 )
 
-            return r
+            # Mark the token as used. This is used to invalidate old refresh
+            # tokens after some time.
+            await self.store.mark_access_token_as_used(user_info.token_id)
+
+            requester = create_requester(
+                user_id=user_info.user_id,
+                access_token_id=user_info.token_id,
+                is_guest=user_info.is_guest,
+                shadow_banned=user_info.shadow_banned,
+                device_id=user_info.device_id,
+                authenticated_entity=user_info.token_owner,
+            )
+
+            return requester
 
         # If the token isn't found in the database, then it could still be a
         # macaroon for a guest, so we check that here.
@@ -445,11 +429,12 @@ class Auth:
                     "Guest access token used for regular user"
                 )
 
-            return TokenLookupResult(
+            return create_requester(
                 user_id=user_id,
                 is_guest=True,
                 # all guests get the same device id
                 device_id=GUEST_DEVICE_ID,
+                authenticated_entity=user_id,
             )
         except (
             pymacaroons.exceptions.MacaroonException,
@@ -472,32 +457,33 @@ class Auth:
         request.requester = create_requester(service.sender, app_service=service)
         return service
 
-    async def is_server_admin(self, user: UserID) -> bool:
+    async def is_server_admin(self, requester: Requester) -> bool:
         """Check if the given user is a local server admin.
 
         Args:
-            user: user to check
+            requester: The user making the request, according to the access token.
 
         Returns:
             True if the user is an admin
         """
-        return await self.store.is_server_admin(user)
+        return await self.store.is_server_admin(requester.user)
 
-    async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
+    async def check_can_change_room_list(
+        self, room_id: str, requester: Requester
+    ) -> bool:
         """Determine whether the user is allowed to edit the room's entry in the
         published room list.
 
         Args:
-            room_id
-            user
+            room_id: The room to check.
+            requester: The user making the request, according to the access token.
         """
 
-        is_admin = await self.is_server_admin(user)
+        is_admin = await self.is_server_admin(requester)
         if is_admin:
             return True
 
-        user_id = user.to_string()
-        await self.check_user_in_room(room_id, user_id)
+        await self.check_user_in_room(room_id, requester)
 
         # We currently require the user is a "moderator" in the room. We do this
         # by checking if they would (theoretically) be able to change the
@@ -516,7 +502,9 @@ class Auth:
         send_level = event_auth.get_send_level(
             EventTypes.CanonicalAlias, "", power_level_event
         )
-        user_level = event_auth.get_user_power_level(user_id, auth_events)
+        user_level = event_auth.get_user_power_level(
+            requester.user.to_string(), auth_events
+        )
 
         return user_level >= send_level
 
@@ -574,16 +562,16 @@ class Auth:
 
     @trace
     async def check_user_in_room_or_world_readable(
-        self, room_id: str, user_id: str, allow_departed_users: bool = False
+        self, room_id: str, requester: Requester, allow_departed_users: bool = False
     ) -> Tuple[str, Optional[str]]:
         """Checks that the user is or was in the room or the room is world
         readable. If it isn't then an exception is raised.
 
         Args:
-            room_id: room to check
-            user_id: user to check
-            allow_departed_users: if True, accept users that were previously
-                members but have now departed
+            room_id: The room to check.
+            requester: The user making the request, according to the access token.
+            allow_departed_users: If True, accept users that were previously
+                members but have now departed.
 
         Returns:
             Resolves to the current membership of the user in the room and the
@@ -598,7 +586,7 @@ class Auth:
             #  * The user is a guest user, and has joined the room
             # else it will throw.
             return await self.check_user_in_room(
-                room_id, user_id, allow_departed_users=allow_departed_users
+                room_id, requester, allow_departed_users=allow_departed_users
             )
         except AuthError:
             visibility = await self._storage_controllers.state.get_current_state_event(
@@ -613,6 +601,6 @@ class Auth:
             raise UnstableSpecAuthError(
                 403,
                 "User %s not in room %s, and room previews are disabled"
-                % (user_id, room_id),
+                % (requester.user, room_id),
                 errcode=Codes.NOT_JOINED,
             )