summary refs log tree commit diff
path: root/synapse/api/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api/auth.py')
-rw-r--r--synapse/api/auth.py159
1 files changed, 73 insertions, 86 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 79e2808dc5..86f145649c 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -36,8 +36,11 @@ logger = logging.getLogger(__name__)
 
 
 AuthEventTypes = (
-    EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
-    EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
+    EventTypes.Create,
+    EventTypes.Member,
+    EventTypes.PowerLevels,
+    EventTypes.JoinRules,
+    EventTypes.RoomHistoryVisibility,
     EventTypes.ThirdPartyInvite,
 )
 
@@ -54,6 +57,7 @@ class Auth(object):
     FIXME: This class contains a mix of functions for authenticating users
     of our client-server API and authenticating events added to room graphs.
     """
+
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
@@ -70,15 +74,12 @@ class Auth(object):
     def check_from_context(self, room_version, event, context, do_sig_check=True):
         prev_state_ids = yield context.get_prev_state_ids(self.store)
         auth_events_ids = yield self.compute_auth_events(
-            event, prev_state_ids, for_verification=True,
+            event, prev_state_ids, for_verification=True
         )
         auth_events = yield self.store.get_events(auth_events_ids)
-        auth_events = {
-            (e.type, e.state_key): e for e in itervalues(auth_events)
-        }
+        auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
         self.check(
-            room_version, event,
-            auth_events=auth_events, do_sig_check=do_sig_check,
+            room_version, event, auth_events=auth_events, do_sig_check=do_sig_check
         )
 
     def check(self, room_version, event, auth_events, do_sig_check=True):
@@ -115,15 +116,10 @@ class Auth(object):
             the room.
         """
         if current_state:
-            member = current_state.get(
-                (EventTypes.Member, user_id),
-                None
-            )
+            member = current_state.get((EventTypes.Member, user_id), None)
         else:
             member = yield self.state.get_current_state(
-                room_id=room_id,
-                event_type=EventTypes.Member,
-                state_key=user_id
+                room_id=room_id, event_type=EventTypes.Member, state_key=user_id
             )
 
         self._check_joined_room(member, user_id, room_id)
@@ -143,23 +139,17 @@ class Auth(object):
             the room. This will be the leave event if they have left the room.
         """
         member = yield self.state.get_current_state(
-            room_id=room_id,
-            event_type=EventTypes.Member,
-            state_key=user_id
+            room_id=room_id, event_type=EventTypes.Member, state_key=user_id
         )
         membership = member.membership if member else None
 
         if membership not in (Membership.JOIN, Membership.LEAVE):
-            raise AuthError(403, "User %s not in room %s" % (
-                user_id, room_id
-            ))
+            raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
 
         if membership == Membership.LEAVE:
             forgot = yield self.store.did_forget(user_id, room_id)
             if forgot:
-                raise AuthError(403, "User %s not in room %s" % (
-                    user_id, room_id
-                ))
+                raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
 
         defer.returnValue(member)
 
@@ -171,9 +161,9 @@ class Auth(object):
 
     def _check_joined_room(self, member, user_id, room_id):
         if not member or member.membership != Membership.JOIN:
-            raise AuthError(403, "User %s not in room %s (%s)" % (
-                user_id, room_id, repr(member)
-            ))
+            raise AuthError(
+                403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
+            )
 
     def can_federate(self, event, auth_events):
         creation_event = auth_events.get((EventTypes.Create, ""))
@@ -185,11 +175,7 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def get_user_by_req(
-        self,
-        request,
-        allow_guest=False,
-        rights="access",
-        allow_expired=False,
+        self, request, allow_guest=False, rights="access", allow_expired=False
     ):
         """ Get a registered user's ID.
 
@@ -209,9 +195,8 @@ class Auth(object):
         try:
             ip_addr = self.hs.get_ip_from_request(request)
             user_agent = request.requestHeaders.getRawHeaders(
-                b"User-Agent",
-                default=[b""]
-            )[0].decode('ascii', 'surrogateescape')
+                b"User-Agent", default=[b""]
+            )[0].decode("ascii", "surrogateescape")
 
             access_token = self.get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
@@ -243,11 +228,12 @@ class Auth(object):
             if self._account_validity.enabled and not allow_expired:
                 user_id = user.to_string()
                 expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
-                if expiration_ts is not None and self.clock.time_msec() >= expiration_ts:
+                if (
+                    expiration_ts is not None
+                    and self.clock.time_msec() >= expiration_ts
+                ):
                     raise AuthError(
-                        403,
-                        "User account has expired",
-                        errcode=Codes.EXPIRED_ACCOUNT,
+                        403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
                     )
 
             # device_id may not be present if get_user_by_access_token has been
@@ -265,18 +251,23 @@ class Auth(object):
 
             if is_guest and not allow_guest:
                 raise AuthError(
-                    403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+                    403,
+                    "Guest access not allowed",
+                    errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
             request.authenticated_entity = user.to_string()
 
-            defer.returnValue(synapse.types.create_requester(
-                user, token_id, is_guest, device_id, app_service=app_service)
+            defer.returnValue(
+                synapse.types.create_requester(
+                    user, token_id, is_guest, device_id, app_service=app_service
+                )
             )
         except KeyError:
             raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
-                errcode=Codes.MISSING_TOKEN
+                self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                "Missing access token.",
+                errcode=Codes.MISSING_TOKEN,
             )
 
     @defer.inlineCallbacks
@@ -297,20 +288,14 @@ class Auth(object):
         if b"user_id" not in request.args:
             defer.returnValue((app_service.sender, app_service))
 
-        user_id = request.args[b"user_id"][0].decode('utf8')
+        user_id = request.args[b"user_id"][0].decode("utf8")
         if app_service.sender == user_id:
             defer.returnValue((app_service.sender, app_service))
 
         if not app_service.is_interested_in_user(user_id):
-            raise AuthError(
-                403,
-                "Application service cannot masquerade as this user."
-            )
+            raise AuthError(403, "Application service cannot masquerade as this user.")
         if not (yield self.store.get_user_by_id(user_id)):
-            raise AuthError(
-                403,
-                "Application service has not registered this user"
-            )
+            raise AuthError(403, "Application service has not registered this user")
         defer.returnValue((user_id, app_service))
 
     @defer.inlineCallbacks
@@ -368,13 +353,13 @@ class Auth(object):
                     raise AuthError(
                         self.TOKEN_NOT_FOUND_HTTP_STATUS,
                         "Unknown user_id %s" % user_id,
-                        errcode=Codes.UNKNOWN_TOKEN
+                        errcode=Codes.UNKNOWN_TOKEN,
                     )
                 if not stored_user["is_guest"]:
                     raise AuthError(
                         self.TOKEN_NOT_FOUND_HTTP_STATUS,
                         "Guest access token used for regular user",
-                        errcode=Codes.UNKNOWN_TOKEN
+                        errcode=Codes.UNKNOWN_TOKEN,
                     )
                 ret = {
                     "user": user,
@@ -402,8 +387,9 @@ class Auth(object):
         ) as e:
             logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
             raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
-                errcode=Codes.UNKNOWN_TOKEN
+                self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                "Invalid macaroon passed.",
+                errcode=Codes.UNKNOWN_TOKEN,
             )
 
     def _parse_and_validate_macaroon(self, token, rights="access"):
@@ -441,13 +427,13 @@ class Auth(object):
                     guest = True
 
             self.validate_macaroon(
-                macaroon, rights, self.hs.config.expire_access_token,
-                user_id=user_id,
+                macaroon, rights, self.hs.config.expire_access_token, user_id=user_id
             )
         except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
             raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
-                errcode=Codes.UNKNOWN_TOKEN
+                self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                "Invalid macaroon passed.",
+                errcode=Codes.UNKNOWN_TOKEN,
             )
 
         if not has_expiry and rights == "access":
@@ -472,10 +458,11 @@ class Auth(object):
         user_prefix = "user_id = "
         for caveat in macaroon.caveats:
             if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix):]
+                return caveat.caveat_id[len(user_prefix) :]
         raise AuthError(
-            self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
-            errcode=Codes.UNKNOWN_TOKEN
+            self.TOKEN_NOT_FOUND_HTTP_STATUS,
+            "No user caveat in macaroon",
+            errcode=Codes.UNKNOWN_TOKEN,
         )
 
     def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
@@ -522,7 +509,7 @@ class Auth(object):
         prefix = "time < "
         if not caveat.startswith(prefix):
             return False
-        expiry = int(caveat[len(prefix):])
+        expiry = int(caveat[len(prefix) :])
         now = self.hs.get_clock().time_msec()
         return now < expiry
 
@@ -554,14 +541,12 @@ class Auth(object):
                 raise AuthError(
                     self.TOKEN_NOT_FOUND_HTTP_STATUS,
                     "Unrecognised access token.",
-                    errcode=Codes.UNKNOWN_TOKEN
+                    errcode=Codes.UNKNOWN_TOKEN,
                 )
             request.authenticated_entity = service.sender
             return defer.succeed(service)
         except KeyError:
-            raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
-            )
+            raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.")
 
     def is_server_admin(self, user):
         """ Check if the given user is a local server admin.
@@ -581,19 +566,19 @@ class Auth(object):
 
         auth_ids = []
 
-        key = (EventTypes.PowerLevels, "", )
+        key = (EventTypes.PowerLevels, "")
         power_level_event_id = current_state_ids.get(key)
 
         if power_level_event_id:
             auth_ids.append(power_level_event_id)
 
-        key = (EventTypes.JoinRules, "", )
+        key = (EventTypes.JoinRules, "")
         join_rule_event_id = current_state_ids.get(key)
 
-        key = (EventTypes.Member, event.sender, )
+        key = (EventTypes.Member, event.sender)
         member_event_id = current_state_ids.get(key)
 
-        key = (EventTypes.Create, "", )
+        key = (EventTypes.Create, "")
         create_event_id = current_state_ids.get(key)
         if create_event_id:
             auth_ids.append(create_event_id)
@@ -619,7 +604,7 @@ class Auth(object):
                     auth_ids.append(member_event_id)
 
                 if for_verification:
-                    key = (EventTypes.Member, event.state_key, )
+                    key = (EventTypes.Member, event.state_key)
                     existing_event_id = current_state_ids.get(key)
                     if existing_event_id:
                         auth_ids.append(existing_event_id)
@@ -628,7 +613,7 @@ class Auth(object):
                 if "third_party_invite" in event.content:
                     key = (
                         EventTypes.ThirdPartyInvite,
-                        event.content["third_party_invite"]["signed"]["token"]
+                        event.content["third_party_invite"]["signed"]["token"],
                     )
                     third_party_invite_id = current_state_ids.get(key)
                     if third_party_invite_id:
@@ -684,7 +669,7 @@ class Auth(object):
             auth_events[(EventTypes.PowerLevels, "")] = power_level_event
 
         send_level = event_auth.get_send_level(
-            EventTypes.Aliases, "", power_level_event,
+            EventTypes.Aliases, "", power_level_event
         )
         user_level = event_auth.get_user_power_level(user_id, auth_events)
 
@@ -692,7 +677,7 @@ class Auth(object):
             raise AuthError(
                 403,
                 "This server requires you to be a moderator in the room to"
-                " edit its room list entry"
+                " edit its room list entry",
             )
 
     @staticmethod
@@ -742,7 +727,7 @@ class Auth(object):
                 )
             parts = auth_headers[0].split(b" ")
             if parts[0] == b"Bearer" and len(parts) == 2:
-                return parts[1].decode('ascii')
+                return parts[1].decode("ascii")
             else:
                 raise AuthError(
                     token_not_found_http_status,
@@ -755,10 +740,10 @@ class Auth(object):
                 raise AuthError(
                     token_not_found_http_status,
                     "Missing access token.",
-                    errcode=Codes.MISSING_TOKEN
+                    errcode=Codes.MISSING_TOKEN,
                 )
 
-            return query_params[0].decode('ascii')
+            return query_params[0].decode("ascii")
 
     @defer.inlineCallbacks
     def check_in_room_or_world_readable(self, room_id, user_id):
@@ -785,8 +770,8 @@ class Auth(object):
                 room_id, EventTypes.RoomHistoryVisibility, ""
             )
             if (
-                visibility and
-                visibility.content["history_visibility"] == "world_readable"
+                visibility
+                and visibility.content["history_visibility"] == "world_readable"
             ):
                 defer.returnValue((Membership.JOIN, None))
                 return
@@ -820,10 +805,11 @@ class Auth(object):
 
         if self.hs.config.hs_disabled:
             raise ResourceLimitError(
-                403, self.hs.config.hs_disabled_message,
+                403,
+                self.hs.config.hs_disabled_message,
                 errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
                 admin_contact=self.hs.config.admin_contact,
-                limit_type=self.hs.config.hs_disabled_limit_type
+                limit_type=self.hs.config.hs_disabled_limit_type,
             )
         if self.hs.config.limit_usage_by_mau is True:
             assert not (user_id and threepid)
@@ -848,8 +834,9 @@ class Auth(object):
             current_mau = yield self.store.get_monthly_active_count()
             if current_mau >= self.hs.config.max_mau_value:
                 raise ResourceLimitError(
-                    403, "Monthly Active User Limit Exceeded",
+                    403,
+                    "Monthly Active User Limit Exceeded",
                     admin_contact=self.hs.config.admin_contact,
                     errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
-                    limit_type="monthly_active_user"
+                    limit_type="monthly_active_user",
                 )