diff options
Diffstat (limited to 'synapse/api/auth.py')
-rw-r--r-- | synapse/api/auth.py | 159 |
1 files changed, 73 insertions, 86 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 79e2808dc5..86f145649c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -36,8 +36,11 @@ logger = logging.getLogger(__name__) AuthEventTypes = ( - EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, - EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, + EventTypes.Create, + EventTypes.Member, + EventTypes.PowerLevels, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, EventTypes.ThirdPartyInvite, ) @@ -54,6 +57,7 @@ class Auth(object): FIXME: This class contains a mix of functions for authenticating users of our client-server API and authenticating events added to room graphs. """ + def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() @@ -70,15 +74,12 @@ class Auth(object): def check_from_context(self, room_version, event, context, do_sig_check=True): prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in itervalues(auth_events) - } + auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} self.check( - room_version, event, - auth_events=auth_events, do_sig_check=do_sig_check, + room_version, event, auth_events=auth_events, do_sig_check=do_sig_check ) def check(self, room_version, event, auth_events, do_sig_check=True): @@ -115,15 +116,10 @@ class Auth(object): the room. """ if current_state: - member = current_state.get( - (EventTypes.Member, user_id), - None - ) + member = current_state.get((EventTypes.Member, user_id), None) else: member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) self._check_joined_room(member, user_id, room_id) @@ -143,23 +139,17 @@ class Auth(object): the room. This will be the leave event if they have left the room. """ member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None if membership not in (Membership.JOIN, Membership.LEAVE): - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) if membership == Membership.LEAVE: forgot = yield self.store.did_forget(user_id, room_id) if forgot: - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) defer.returnValue(member) @@ -171,9 +161,9 @@ class Auth(object): def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) + raise AuthError( + 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) + ) def can_federate(self, event, auth_events): creation_event = auth_events.get((EventTypes.Create, "")) @@ -185,11 +175,7 @@ class Auth(object): @defer.inlineCallbacks def get_user_by_req( - self, - request, - allow_guest=False, - rights="access", - allow_expired=False, + self, request, allow_guest=False, rights="access", allow_expired=False ): """ Get a registered user's ID. @@ -209,9 +195,8 @@ class Auth(object): try: ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( - b"User-Agent", - default=[b""] - )[0].decode('ascii', 'surrogateescape') + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") access_token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS @@ -243,11 +228,12 @@ class Auth(object): if self._account_validity.enabled and not allow_expired: user_id = user.to_string() expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - if expiration_ts is not None and self.clock.time_msec() >= expiration_ts: + if ( + expiration_ts is not None + and self.clock.time_msec() >= expiration_ts + ): raise AuthError( - 403, - "User account has expired", - errcode=Codes.EXPIRED_ACCOUNT, + 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) # device_id may not be present if get_user_by_access_token has been @@ -265,18 +251,23 @@ class Auth(object): if is_guest and not allow_guest: raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + 403, + "Guest access not allowed", + errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) request.authenticated_entity = user.to_string() - defer.returnValue(synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service) + defer.returnValue( + synapse.types.create_requester( + user, token_id, is_guest, device_id, app_service=app_service + ) ) except KeyError: raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", - errcode=Codes.MISSING_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Missing access token.", + errcode=Codes.MISSING_TOKEN, ) @defer.inlineCallbacks @@ -297,20 +288,14 @@ class Auth(object): if b"user_id" not in request.args: defer.returnValue((app_service.sender, app_service)) - user_id = request.args[b"user_id"][0].decode('utf8') + user_id = request.args[b"user_id"][0].decode("utf8") if app_service.sender == user_id: defer.returnValue((app_service.sender, app_service)) if not app_service.is_interested_in_user(user_id): - raise AuthError( - 403, - "Application service cannot masquerade as this user." - ) + raise AuthError(403, "Application service cannot masquerade as this user.") if not (yield self.store.get_user_by_id(user_id)): - raise AuthError( - 403, - "Application service has not registered this user" - ) + raise AuthError(403, "Application service has not registered this user") defer.returnValue((user_id, app_service)) @defer.inlineCallbacks @@ -368,13 +353,13 @@ class Auth(object): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unknown user_id %s" % user_id, - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) if not stored_user["is_guest"]: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Guest access token used for regular user", - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) ret = { "user": user, @@ -402,8 +387,9 @@ class Auth(object): ) as e: logger.warning("Invalid macaroon in auth: %s %s", type(e), e) raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN, ) def _parse_and_validate_macaroon(self, token, rights="access"): @@ -441,13 +427,13 @@ class Auth(object): guest = True self.validate_macaroon( - macaroon, rights, self.hs.config.expire_access_token, - user_id=user_id, + macaroon, rights, self.hs.config.expire_access_token, user_id=user_id ) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN, ) if not has_expiry and rights == "access": @@ -472,10 +458,11 @@ class Auth(object): user_prefix = "user_id = " for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] + return caveat.caveat_id[len(user_prefix) :] raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN, ) def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): @@ -522,7 +509,7 @@ class Auth(object): prefix = "time < " if not caveat.startswith(prefix): return False - expiry = int(caveat[len(prefix):]) + expiry = int(caveat[len(prefix) :]) now = self.hs.get_clock().time_msec() return now < expiry @@ -554,14 +541,12 @@ class Auth(object): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) request.authenticated_entity = service.sender return defer.succeed(service) except KeyError: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." - ) + raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.") def is_server_admin(self, user): """ Check if the given user is a local server admin. @@ -581,19 +566,19 @@ class Auth(object): auth_ids = [] - key = (EventTypes.PowerLevels, "", ) + key = (EventTypes.PowerLevels, "") power_level_event_id = current_state_ids.get(key) if power_level_event_id: auth_ids.append(power_level_event_id) - key = (EventTypes.JoinRules, "", ) + key = (EventTypes.JoinRules, "") join_rule_event_id = current_state_ids.get(key) - key = (EventTypes.Member, event.sender, ) + key = (EventTypes.Member, event.sender) member_event_id = current_state_ids.get(key) - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create_event_id = current_state_ids.get(key) if create_event_id: auth_ids.append(create_event_id) @@ -619,7 +604,7 @@ class Auth(object): auth_ids.append(member_event_id) if for_verification: - key = (EventTypes.Member, event.state_key, ) + key = (EventTypes.Member, event.state_key) existing_event_id = current_state_ids.get(key) if existing_event_id: auth_ids.append(existing_event_id) @@ -628,7 +613,7 @@ class Auth(object): if "third_party_invite" in event.content: key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) third_party_invite_id = current_state_ids.get(key) if third_party_invite_id: @@ -684,7 +669,7 @@ class Auth(object): auth_events[(EventTypes.PowerLevels, "")] = power_level_event send_level = event_auth.get_send_level( - EventTypes.Aliases, "", power_level_event, + EventTypes.Aliases, "", power_level_event ) user_level = event_auth.get_user_power_level(user_id, auth_events) @@ -692,7 +677,7 @@ class Auth(object): raise AuthError( 403, "This server requires you to be a moderator in the room to" - " edit its room list entry" + " edit its room list entry", ) @staticmethod @@ -742,7 +727,7 @@ class Auth(object): ) parts = auth_headers[0].split(b" ") if parts[0] == b"Bearer" and len(parts) == 2: - return parts[1].decode('ascii') + return parts[1].decode("ascii") else: raise AuthError( token_not_found_http_status, @@ -755,10 +740,10 @@ class Auth(object): raise AuthError( token_not_found_http_status, "Missing access token.", - errcode=Codes.MISSING_TOKEN + errcode=Codes.MISSING_TOKEN, ) - return query_params[0].decode('ascii') + return query_params[0].decode("ascii") @defer.inlineCallbacks def check_in_room_or_world_readable(self, room_id, user_id): @@ -785,8 +770,8 @@ class Auth(object): room_id, EventTypes.RoomHistoryVisibility, "" ) if ( - visibility and - visibility.content["history_visibility"] == "world_readable" + visibility + and visibility.content["history_visibility"] == "world_readable" ): defer.returnValue((Membership.JOIN, None)) return @@ -820,10 +805,11 @@ class Auth(object): if self.hs.config.hs_disabled: raise ResourceLimitError( - 403, self.hs.config.hs_disabled_message, + 403, + self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=self.hs.config.admin_contact, - limit_type=self.hs.config.hs_disabled_limit_type + limit_type=self.hs.config.hs_disabled_limit_type, ) if self.hs.config.limit_usage_by_mau is True: assert not (user_id and threepid) @@ -848,8 +834,9 @@ class Auth(object): current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise ResourceLimitError( - 403, "Monthly Active User Limit Exceeded", + 403, + "Monthly Active User Limit Exceeded", admin_contact=self.hs.config.admin_contact, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type="monthly_active_user" + limit_type="monthly_active_user", ) |