summary refs log tree commit diff
path: root/synapse/api/auth.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-06 08:30:06 -0400
committerGitHub <noreply@github.com>2020-08-06 08:30:06 -0400
commitd4a7829b12197faf52eb487c443ee09acafeb37e (patch)
treecab5f15a7532596153f61b47aafcf4cb4a4b7d45 /synapse/api/auth.py
parentConvert run_as_background_process inner function to async. (#8032) (diff)
downloadsynapse-d4a7829b12197faf52eb487c443ee09acafeb37e.tar.xz
Convert synapse.api to async/await (#8031)
Diffstat (limited to 'synapse/api/auth.py')
-rw-r--r--synapse/api/auth.py123
1 files changed, 56 insertions, 67 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 2178e623da..d8190f92ab 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import pymacaroons
 from netaddr import IPAddress
 
-from twisted.internet import defer
 from twisted.web.server import Request
 
 import synapse.types
@@ -80,13 +79,14 @@ class Auth(object):
         self._track_appservice_user_ips = hs.config.track_appservice_user_ips
         self._macaroon_secret_key = hs.config.macaroon_secret_key
 
-    @defer.inlineCallbacks
-    def check_from_context(self, room_version: str, event, context, do_sig_check=True):
-        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
-        auth_events_ids = yield self.compute_auth_events(
+    async def check_from_context(
+        self, room_version: str, event, context, do_sig_check=True
+    ):
+        prev_state_ids = await context.get_prev_state_ids()
+        auth_events_ids = self.compute_auth_events(
             event, prev_state_ids, for_verification=True
         )
-        auth_events = yield self.store.get_events(auth_events_ids)
+        auth_events = await self.store.get_events(auth_events_ids)
         auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
 
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -94,14 +94,13 @@ class Auth(object):
             room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
         )
 
-    @defer.inlineCallbacks
-    def check_user_in_room(
+    async def check_user_in_room(
         self,
         room_id: str,
         user_id: str,
         current_state: Optional[StateMap[EventBase]] = None,
         allow_departed_users: bool = False,
-    ):
+    ) -> EventBase:
         """Check if the user is in the room, or was at some point.
         Args:
             room_id: The room to check.
@@ -119,37 +118,35 @@ class Auth(object):
         Raises:
             AuthError if the user is/was not in the room.
         Returns:
-            Deferred[Optional[EventBase]]:
-                Membership event for the user if the user was in the
-                room. This will be the join event if they are currently joined to
-                the room. This will be the leave event if they have left the room.
+            Membership event for the user if the user was in the
+            room. This will be the join event if they are currently joined to
+            the room. This will be the leave event if they have left the room.
         """
         if current_state:
             member = current_state.get((EventTypes.Member, user_id), None)
         else:
-            member = yield defer.ensureDeferred(
-                self.state.get_current_state(
-                    room_id=room_id, event_type=EventTypes.Member, state_key=user_id
-                )
+            member = await self.state.get_current_state(
+                room_id=room_id, event_type=EventTypes.Member, state_key=user_id
             )
-        membership = member.membership if member else None
 
-        if membership == Membership.JOIN:
-            return member
+        if member:
+            membership = member.membership
 
-        # XXX this looks totally bogus. Why do we not allow users who have been banned,
-        # or those who were members previously and have been re-invited?
-        if allow_departed_users and membership == Membership.LEAVE:
-            forgot = yield self.store.did_forget(user_id, room_id)
-            if not forgot:
+            if membership == Membership.JOIN:
                 return member
 
+            # XXX this looks totally bogus. Why do we not allow users who have been banned,
+            # or those who were members previously and have been re-invited?
+            if allow_departed_users and membership == Membership.LEAVE:
+                forgot = await self.store.did_forget(user_id, room_id)
+                if not forgot:
+                    return member
+
         raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
 
-    @defer.inlineCallbacks
-    def check_host_in_room(self, room_id, host):
+    async def check_host_in_room(self, room_id, host):
         with Measure(self.clock, "check_host_in_room"):
-            latest_event_ids = yield self.store.is_host_joined(room_id, host)
+            latest_event_ids = await self.store.is_host_joined(room_id, host)
             return latest_event_ids
 
     def can_federate(self, event, auth_events):
@@ -160,14 +157,13 @@ class Auth(object):
     def get_public_keys(self, invite_event):
         return event_auth.get_public_keys(invite_event)
 
-    @defer.inlineCallbacks
-    def get_user_by_req(
+    async def get_user_by_req(
         self,
         request: Request,
         allow_guest: bool = False,
         rights: str = "access",
         allow_expired: bool = False,
-    ):
+    ) -> synapse.types.Requester:
         """ Get a registered user's ID.
 
         Args:
@@ -180,7 +176,7 @@ class Auth(object):
                 /login will deliver access tokens regardless of expiration.
 
         Returns:
-            defer.Deferred: resolves to a `synapse.types.Requester` object
+            Resolves to the requester
         Raises:
             InvalidClientCredentialsError if no user by that token exists or the token
                 is invalid.
@@ -194,14 +190,14 @@ class Auth(object):
 
             access_token = self.get_access_token_from_request(request)
 
-            user_id, app_service = yield self._get_appservice_user_id(request)
+            user_id, app_service = await self._get_appservice_user_id(request)
             if user_id:
                 request.authenticated_entity = user_id
                 opentracing.set_tag("authenticated_entity", user_id)
                 opentracing.set_tag("appservice_id", app_service.id)
 
                 if ip_addr and self._track_appservice_user_ips:
-                    yield self.store.insert_client_ip(
+                    await self.store.insert_client_ip(
                         user_id=user_id,
                         access_token=access_token,
                         ip=ip_addr,
@@ -211,7 +207,7 @@ class Auth(object):
 
                 return synapse.types.create_requester(user_id, app_service=app_service)
 
-            user_info = yield self.get_user_by_access_token(
+            user_info = await self.get_user_by_access_token(
                 access_token, rights, allow_expired=allow_expired
             )
             user = user_info["user"]
@@ -221,7 +217,7 @@ class Auth(object):
             # Deny the request if the user account has expired.
             if self._account_validity.enabled and not allow_expired:
                 user_id = user.to_string()
-                expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
+                expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
                 if (
                     expiration_ts is not None
                     and self.clock.time_msec() >= expiration_ts
@@ -235,7 +231,7 @@ class Auth(object):
             device_id = user_info.get("device_id")
 
             if user and access_token and ip_addr:
-                yield self.store.insert_client_ip(
+                await self.store.insert_client_ip(
                     user_id=user.to_string(),
                     access_token=access_token,
                     ip=ip_addr,
@@ -261,8 +257,7 @@ class Auth(object):
         except KeyError:
             raise MissingClientTokenError()
 
-    @defer.inlineCallbacks
-    def _get_appservice_user_id(self, request):
+    async def _get_appservice_user_id(self, request):
         app_service = self.store.get_app_service_by_token(
             self.get_access_token_from_request(request)
         )
@@ -283,14 +278,13 @@ class Auth(object):
 
         if not app_service.is_interested_in_user(user_id):
             raise AuthError(403, "Application service cannot masquerade as this user.")
-        if not (yield self.store.get_user_by_id(user_id)):
+        if not (await self.store.get_user_by_id(user_id)):
             raise AuthError(403, "Application service has not registered this user")
         return user_id, app_service
 
-    @defer.inlineCallbacks
-    def get_user_by_access_token(
+    async def get_user_by_access_token(
         self, token: str, rights: str = "access", allow_expired: bool = False,
-    ):
+    ) -> dict:
         """ Validate access token and get user_id from it
 
         Args:
@@ -300,7 +294,7 @@ class Auth(object):
             allow_expired: If False, raises an InvalidClientTokenError
                 if the token is expired
         Returns:
-            Deferred[dict]: dict that includes:
+            dict that includes:
                `user` (UserID)
                `is_guest` (bool)
                `token_id` (int|None): access token id. May be None if guest
@@ -314,7 +308,7 @@ class Auth(object):
 
         if rights == "access":
             # first look in the database
-            r = yield self._look_up_user_by_access_token(token)
+            r = await self._look_up_user_by_access_token(token)
             if r:
                 valid_until_ms = r["valid_until_ms"]
                 if (
@@ -352,7 +346,7 @@ class Auth(object):
                 # It would of course be much easier to store guest access
                 # tokens in the database as well, but that would break existing
                 # guest tokens.
-                stored_user = yield self.store.get_user_by_id(user_id)
+                stored_user = await self.store.get_user_by_id(user_id)
                 if not stored_user:
                     raise InvalidClientTokenError("Unknown user_id %s" % user_id)
                 if not stored_user["is_guest"]:
@@ -482,9 +476,8 @@ class Auth(object):
         now = self.hs.get_clock().time_msec()
         return now < expiry
 
-    @defer.inlineCallbacks
-    def _look_up_user_by_access_token(self, token):
-        ret = yield self.store.get_user_by_access_token(token)
+    async def _look_up_user_by_access_token(self, token):
+        ret = await self.store.get_user_by_access_token(token)
         if not ret:
             return None
 
@@ -507,7 +500,7 @@ class Auth(object):
             logger.warning("Unrecognised appservice access token.")
             raise InvalidClientTokenError()
         request.authenticated_entity = service.sender
-        return defer.succeed(service)
+        return service
 
     async def is_server_admin(self, user: UserID) -> bool:
         """ Check if the given user is a local server admin.
@@ -522,7 +515,7 @@ class Auth(object):
 
     def compute_auth_events(
         self, event, current_state_ids: StateMap[str], for_verification: bool = False,
-    ):
+    ) -> List[str]:
         """Given an event and current state return the list of event IDs used
         to auth an event.
 
@@ -530,11 +523,11 @@ class Auth(object):
         should be added to the event's `auth_events`.
 
         Returns:
-            defer.Deferred(list[str]): List of event IDs.
+            List of event IDs.
         """
 
         if event.type == EventTypes.Create:
-            return defer.succeed([])
+            return []
 
         # Currently we ignore the `for_verification` flag even though there are
         # some situations where we can drop particular auth events when adding
@@ -553,7 +546,7 @@ class Auth(object):
             if auth_ev_id:
                 auth_ids.append(auth_ev_id)
 
-        return defer.succeed(auth_ids)
+        return auth_ids
 
     async def check_can_change_room_list(self, room_id: str, user: UserID):
         """Determine whether the user is allowed to edit the room's entry in the
@@ -636,10 +629,9 @@ class Auth(object):
 
             return query_params[0].decode("ascii")
 
-    @defer.inlineCallbacks
-    def check_user_in_room_or_world_readable(
+    async def check_user_in_room_or_world_readable(
         self, room_id: str, user_id: str, allow_departed_users: bool = False
-    ):
+    ) -> Tuple[str, Optional[str]]:
         """Checks that the user is or was in the room or the room is world
         readable. If it isn't then an exception is raised.
 
@@ -650,10 +642,9 @@ class Auth(object):
                 members but have now departed
 
         Returns:
-            Deferred[tuple[str, str|None]]: Resolves to the current membership of
-                the user in the room and the membership event ID of the user. If
-                the user is not in the room and never has been, then
-                `(Membership.JOIN, None)` is returned.
+            Resolves to the current membership of the user in the room and the
+            membership event ID of the user. If the user is not in the room and
+            never has been, then `(Membership.JOIN, None)` is returned.
         """
 
         try:
@@ -662,15 +653,13 @@ class Auth(object):
             #  * The user is a non-guest user, and was ever in the room
             #  * The user is a guest user, and has joined the room
             # else it will throw.
-            member_event = yield self.check_user_in_room(
+            member_event = await self.check_user_in_room(
                 room_id, user_id, allow_departed_users=allow_departed_users
             )
             return member_event.membership, member_event.event_id
         except AuthError:
-            visibility = yield defer.ensureDeferred(
-                self.state.get_current_state(
-                    room_id, EventTypes.RoomHistoryVisibility, ""
-                )
+            visibility = await self.state.get_current_state(
+                room_id, EventTypes.RoomHistoryVisibility, ""
             )
             if (
                 visibility