You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
261 lines
6.8 KiB
261 lines
6.8 KiB
package cognitoauth
|
|
|
|
import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"git.aiterp.net/stufflog3/stufflog3/entities"
	"git.aiterp.net/stufflog3/stufflog3/models"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cognitoidentityprovider"
	"github.com/dgrijalva/jwt-go/v4"
	"github.com/lestrrat-go/jwx/jwa"
	"github.com/lestrrat-go/jwx/jwk"
)
|
|
|
|
type Client struct {
|
|
poolID string
|
|
poolClientID string
|
|
poolClientSecret string
|
|
session *session.Session
|
|
keySet jwk.Set
|
|
}
|
|
|
|
// ListUsers returns every user in the pool. It follows Cognito's
// PaginationToken until the listing is exhausted, because a single
// ListUsers call returns only one page.
//
// Only the "sub" attribute is read; it becomes the returned user's ID.
// Users without a "sub" attribute are still included, with an empty ID.
func (c *Client) ListUsers(ctx context.Context) ([]entities.User, error) {
	cognitoClient := cognitoidentityprovider.New(c.session)

	input := &cognitoidentityprovider.ListUsersInput{
		UserPoolId: aws.String(c.poolID),
	}

	users := make([]entities.User, 0, 16)
	for {
		res, err := cognitoClient.ListUsersWithContext(ctx, input)
		if err != nil {
			return nil, err
		}

		for _, u := range res.Users {
			user := entities.User{}
			for _, attr := range u.Attributes {
				// aws.StringValue guards against nil pointers in the response.
				if aws.StringValue(attr.Name) == "sub" {
					user.ID = aws.StringValue(attr.Value)
				}
			}
			users = append(users, user)
		}

		// An empty or absent PaginationToken means this was the last page.
		if aws.StringValue(res.PaginationToken) == "" {
			return users, nil
		}
		input.PaginationToken = res.PaginationToken
	}
}
|
|
|
|
// RefreshUser exchanges refreshToken for a new ID token using the
// REFRESH_TOKEN_AUTH flow.
//
// token is the caller's current (possibly expired) ID token. It is parsed
// without signature verification, and only to recover the username used to
// compute the SECRET_HASH. The user in the result comes from the newly issued,
// signature-verified ID token, so a tampered input token cannot change the
// identity that is returned.
//
// If Cognito's response carries no new refresh token, the caller's
// refreshToken is returned unchanged instead of being blanked.
//
// Returns models.PermissionDeniedError if token cannot be parsed, if Cognito
// returns no ID token, or if the new ID token fails validation.
func (c *Client) RefreshUser(ctx context.Context, token, refreshToken string) (*entities.AuthResult, error) {
	oldUser := c.parseToken(token)
	if oldUser == nil {
		return nil, models.PermissionDeniedError{}
	}

	// hash.Hash.Write never returns an error.
	mac := hmac.New(sha256.New, []byte(c.poolClientSecret))
	mac.Write([]byte(oldUser.Name + c.poolClientID))
	secretHash := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	cognitoClient := cognitoidentityprovider.New(c.session)

	res, err := cognitoClient.InitiateAuthWithContext(ctx, &cognitoidentityprovider.InitiateAuthInput{
		AuthFlow: aws.String("REFRESH_TOKEN_AUTH"),
		AuthParameters: map[string]*string{
			"REFRESH_TOKEN": aws.String(refreshToken),
			"SECRET_HASH":   aws.String(secretHash),
		},
		ClientId: aws.String(c.poolClientID),
	})
	if err != nil {
		return nil, err
	}

	if res.AuthenticationResult == nil || res.AuthenticationResult.IdToken == nil {
		return nil, models.PermissionDeniedError{}
	}

	idToken := *res.AuthenticationResult.IdToken

	// Trust only the freshly issued, signature-checked token for identity.
	user := c.ValidateToken(ctx, idToken)
	if user == nil {
		return nil, models.PermissionDeniedError{}
	}

	// Keep the caller's refresh token when Cognito does not rotate it.
	if newRefreshToken := aws.StringValue(res.AuthenticationResult.RefreshToken); newRefreshToken != "" {
		refreshToken = newRefreshToken
	}

	return &entities.AuthResult{
		User:         user,
		Token:        idToken,
		RefreshToken: refreshToken,
	}, nil
}
|
|
|
|
// LoginUser authenticates username and password with the USER_PASSWORD_AUTH
// flow.
//
// If Cognito answers with a NEW_PASSWORD_REQUIRED challenge, the result has
// PasswordChangeRequired set and carries the challenge Session, which
// SetupUser expects. Any other challenge yields models.NotImplementedError.
//
// On success, the result holds the verified user, the ID token and, if one
// was issued, the refresh token. Returns models.PermissionDeniedError if no
// ID token is returned or if it fails validation.
func (c *Client) LoginUser(ctx context.Context, username, password string) (*entities.AuthResult, error) {
	// hash.Hash.Write never returns an error.
	mac := hmac.New(sha256.New, []byte(c.poolClientSecret))
	mac.Write([]byte(username + c.poolClientID))
	secretHash := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	cognitoClient := cognitoidentityprovider.New(c.session)

	res, err := cognitoClient.InitiateAuthWithContext(ctx, &cognitoidentityprovider.InitiateAuthInput{
		AuthFlow: aws.String("USER_PASSWORD_AUTH"),
		AuthParameters: map[string]*string{
			"USERNAME":    aws.String(username),
			"PASSWORD":    aws.String(password),
			"SECRET_HASH": aws.String(secretHash),
		},
		ClientId: aws.String(c.poolClientID),
	})
	if err != nil {
		return nil, err
	}

	if res.ChallengeName != nil {
		if *res.ChallengeName != "NEW_PASSWORD_REQUIRED" {
			return nil, models.NotImplementedError{
				Function: "cognitoauth.Client",
				Message:  "Missing handler for challenge: " + *res.ChallengeName,
			}
		}

		// aws.StringValue avoids a nil dereference if Session is absent.
		return &entities.AuthResult{
			Session:                aws.StringValue(res.Session),
			PasswordChangeRequired: true,
		}, nil
	}

	if res.AuthenticationResult == nil || res.AuthenticationResult.IdToken == nil {
		return nil, models.PermissionDeniedError{}
	}

	idToken := *res.AuthenticationResult.IdToken

	user := c.ValidateToken(ctx, idToken)
	if user == nil {
		return nil, models.PermissionDeniedError{}
	}

	return &entities.AuthResult{
		User:         user,
		Token:        idToken,
		RefreshToken: aws.StringValue(res.AuthenticationResult.RefreshToken),
	}, nil
}
|
|
|
|
// SetupUser answers a NEW_PASSWORD_REQUIRED challenge by setting newPassword
// for username. session is the challenge Session returned by LoginUser.
//
// Returns the user described by the newly issued, signature-verified ID token.
// Returns models.PermissionDeniedError if Cognito issues no ID token or the
// token fails validation, so a nil user is never returned with a nil error.
func (c *Client) SetupUser(ctx context.Context, session, username, newPassword string) (*entities.User, error) {
	// hash.Hash.Write never returns an error.
	mac := hmac.New(sha256.New, []byte(c.poolClientSecret))
	mac.Write([]byte(username + c.poolClientID))
	secretHash := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	cognitoClient := cognitoidentityprovider.New(c.session)

	res, err := cognitoClient.RespondToAuthChallengeWithContext(ctx, &cognitoidentityprovider.RespondToAuthChallengeInput{
		ChallengeName: aws.String("NEW_PASSWORD_REQUIRED"),
		ChallengeResponses: map[string]*string{
			"NEW_PASSWORD": aws.String(newPassword),
			"USERNAME":     aws.String(username),
			"SECRET_HASH":  aws.String(secretHash),
		},
		Session:  aws.String(session),
		ClientId: aws.String(c.poolClientID),
	})
	if err != nil {
		return nil, err
	}

	if res.AuthenticationResult == nil || res.AuthenticationResult.IdToken == nil {
		return nil, models.PermissionDeniedError{}
	}

	user := c.ValidateToken(ctx, *res.AuthenticationResult.IdToken)
	if user == nil {
		return nil, models.PermissionDeniedError{}
	}

	return user, nil
}
|
|
|
|
// parseToken extracts the user identity from a JWT's payload WITHOUT
// verifying its signature. Use ValidateToken for tokens that are not already
// trusted.
//
// The user ID is the "custom:actual_userid" claim when it is present as a
// string, and the "sub" claim otherwise. The name is the "cognito:username"
// claim. Returns nil if the token does not have three dot-separated segments,
// or if its payload is not base64url-encoded JSON.
func (c *Client) parseToken(token string) *entities.User {
	split := strings.SplitN(token, ".", 3)
	if len(split) != 3 {
		return nil
	}

	// JWT segments use the unpadded base64url alphabet (RFC 7515). The
	// standard alphabet rejects payloads whose encoding contains '-' or '_'.
	// Any trailing '=' padding is trimmed so padded segments also decode.
	data, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(split[1], "="))
	if err != nil {
		return nil
	}

	m := make(map[string]interface{}, 16)
	if err := json.Unmarshal(data, &m); err != nil {
		return nil
	}

	// Non-string claim values are ignored (treated as absent).
	userID, _ := m["sub"].(string)
	if actualID, ok := m["custom:actual_userid"].(string); ok {
		userID = actualID
	}
	userName, _ := m["cognito:username"].(string)

	return &entities.User{
		ID:   userID,
		Name: userName,
	}
}
|
|
|
|
// ValidateToken verifies token's RS256 signature against the user pool's key
// set. If the token passes, it returns the user described by the token's claims.
//
// Returns nil if any of the following holds:
//   - the token is not signed with RS256;
//   - it lacks a string "kid" header, or the kid is not in the key set;
//   - jwt.Parse rejects it for any other reason.
//
// Audience validation is disabled. The context argument is unused.
func (c *Client) ValidateToken(_ context.Context, token string) *entities.User {
	// keyFunc resolves the verification key for a parsed token from its kid.
	keyFunc := func(parsed *jwt.Token) (interface{}, error) {
		if parsed.Method.Alg() != jwa.RS256.String() {
			return nil, fmt.Errorf("unexpected signing method: %v", parsed.Header["alg"])
		}

		kid, hasKid := parsed.Header["kid"].(string)
		if !hasKid {
			return nil, errors.New("kid header not found")
		}

		key, found := c.keySet.LookupKeyID(kid)
		if !found {
			return nil, fmt.Errorf("key %v not found", kid)
		}

		var raw interface{}
		rawErr := key.Raw(&raw)
		return raw, rawErr
	}

	if _, err := jwt.Parse(token, keyFunc, jwt.WithoutAudienceValidation()); err != nil {
		return nil
	}

	return c.parseToken(token)
}
|
|
|
|
// New creates a Client for a Cognito user pool.
//
// regionId is the AWS region. clientID and clientSecret are static AWS
// credentials used for Cognito API calls. poolId identifies the user pool,
// and poolClientId and poolClientSecret identify its app client.
//
// New downloads the pool's JSON Web Key Set from
// https://cognito-idp.<region>.amazonaws.com/<pool>/.well-known/jwks.json.
// This is network I/O, bounded by a 30-second timeout. New fails if the key
// set cannot be fetched. Returned errors wrap the underlying cause (%w).
func New(regionId, clientID, clientSecret, poolId, poolClientId, poolClientSecret string) (*Client, error) {
	s, err := session.NewSession(&aws.Config{
		Region:      aws.String(regionId),
		Credentials: credentials.NewStaticCredentials(clientID, clientSecret, ""),
	})
	if err != nil {
		return nil, fmt.Errorf("creating AWS session: %w", err)
	}

	// Without a deadline, an unreachable JWKS endpoint would block the caller
	// indefinitely.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	jwksURL := fmt.Sprintf(
		"https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json",
		regionId,
		poolId,
	)
	keySet, err := jwk.Fetch(ctx, jwksURL)
	if err != nil {
		return nil, fmt.Errorf("fetching JWKS from %s: %w", jwksURL, err)
	}

	return &Client{
		poolID:           poolId,
		poolClientID:     poolClientId,
		poolClientSecret: poolClientSecret,
		session:          s,
		keySet:           keySet,
	}, nil
}
|