Loggest thine Stuff
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

261 lines
6.8 KiB

package cognitoauth
import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"git.aiterp.net/stufflog3/stufflog3/entities"
	"git.aiterp.net/stufflog3/stufflog3/models"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cognitoidentityprovider"
	"github.com/dgrijalva/jwt-go/v4"
	"github.com/lestrrat-go/jwx/jwa"
	"github.com/lestrrat-go/jwx/jwk"
)
// Client authenticates users against an AWS Cognito user pool and verifies
// the ID tokens that pool issues. Construct it with New, which fills in the
// AWS session and the pool's signing keys.
type Client struct {
	// poolID identifies the Cognito user pool.
	poolID string

	// App client credentials. The secret is used to compute the SECRET_HASH
	// parameter sent with Cognito authentication requests.
	poolClientID     string
	poolClientSecret string

	// session is the AWS session used to build Cognito API clients.
	session *session.Session

	// keySet holds the pool's JSON Web Keys, used to check token signatures.
	keySet jwk.Set
}
// ListUsers returns every user in the Cognito user pool.
//
// Cognito returns users one page at a time, so this keeps requesting pages
// (following PaginationToken) until the last page has been read.
//
// Each user's ID is the "custom:actual_userid" attribute when present and
// "sub" otherwise. This matches how parseToken derives the ID from an ID
// token, so IDs from this list agree with IDs of authenticated users. Name is
// the Cognito username, the same value as the token's "cognito:username".
//
// Errors from the Cognito API are returned unchanged.
func (c *Client) ListUsers(ctx context.Context) ([]entities.User, error) {
	cognitoClient := cognitoidentityprovider.New(c.session)
	input := &cognitoidentityprovider.ListUsersInput{
		UserPoolId: aws.String(c.poolID),
	}

	users := make([]entities.User, 0, 16)
	for {
		res, err := cognitoClient.ListUsersWithContext(ctx, input)
		if err != nil {
			return nil, err
		}

		for _, u := range res.Users {
			var sub, actualUserID string
			hasActualUserID := false
			for _, attr := range u.Attributes {
				// aws.StringValue guards against nil pointers in the response.
				switch aws.StringValue(attr.Name) {
				case "sub":
					sub = aws.StringValue(attr.Value)
				case "custom:actual_userid":
					actualUserID = aws.StringValue(attr.Value)
					hasActualUserID = true
				}
			}

			id := sub
			if hasActualUserID {
				id = actualUserID
			}
			users = append(users, entities.User{
				ID:   id,
				Name: aws.StringValue(u.Username),
			})
		}

		// An absent or empty token means this was the last page.
		if aws.StringValue(res.PaginationToken) == "" {
			break
		}
		input.PaginationToken = res.PaginationToken
	}

	return users, nil
}
// RefreshUser exchanges refreshToken for a new ID token using Cognito's
// REFRESH_TOKEN_AUTH flow.
//
// token is the caller's current, possibly expired, ID token. It is decoded
// WITHOUT signature verification, and only to get the Cognito username needed
// for SECRET_HASH. The returned User comes from the newly issued ID token
// after ValidateToken checks it, so a tampered token cannot change whose
// identity the result carries.
//
// If Cognito returns no new refresh token, the refreshToken passed in is
// returned unchanged.
//
// It returns models.PermissionDeniedError when token cannot be decoded, when
// Cognito returns no ID token, or when the new ID token fails validation.
// Errors from the Cognito API are returned unchanged.
func (c *Client) RefreshUser(ctx context.Context, token, refreshToken string) (*entities.AuthResult, error) {
	claimed := c.parseToken(token)
	if claimed == nil {
		return nil, models.PermissionDeniedError{}
	}

	// SECRET_HASH = Base64(HMAC-SHA256(clientSecret, username + clientID)).
	mac := hmac.New(sha256.New, []byte(c.poolClientSecret))
	mac.Write([]byte(claimed.Name + c.poolClientID))
	secretHash := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	cognitoClient := cognitoidentityprovider.New(c.session)
	res, err := cognitoClient.InitiateAuthWithContext(ctx, &cognitoidentityprovider.InitiateAuthInput{
		AuthFlow: aws.String("REFRESH_TOKEN_AUTH"),
		AuthParameters: map[string]*string{
			"REFRESH_TOKEN": aws.String(refreshToken),
			"SECRET_HASH":   aws.String(secretHash),
		},
		ClientId: aws.String(c.poolClientID),
	})
	if err != nil {
		return nil, err
	}
	if res.AuthenticationResult == nil || res.AuthenticationResult.IdToken == nil {
		return nil, models.PermissionDeniedError{}
	}

	idToken := *res.AuthenticationResult.IdToken
	user := c.ValidateToken(ctx, idToken)
	if user == nil {
		return nil, models.PermissionDeniedError{}
	}

	// Without refresh-token rotation Cognito returns no new refresh token,
	// and the one we were given is still the one to keep.
	if rotated := aws.StringValue(res.AuthenticationResult.RefreshToken); rotated != "" {
		refreshToken = rotated
	}

	return &entities.AuthResult{
		User:         user,
		Token:        idToken,
		RefreshToken: refreshToken,
	}, nil
}
// LoginUser authenticates username and password using Cognito's
// USER_PASSWORD_AUTH flow.
//
// If Cognito answers with a NEW_PASSWORD_REQUIRED challenge, the result has
// PasswordChangeRequired set and carries the challenge Session, which
// SetupUser accepts to complete the challenge. Any other challenge yields a
// models.NotImplementedError.
//
// On success the result holds the validated user, the ID token, and the
// refresh token (empty if Cognito returned none). It returns
// models.PermissionDeniedError when Cognito issues no ID token or the ID
// token fails validation. Errors from the Cognito API are returned unchanged.
func (c *Client) LoginUser(ctx context.Context, username, password string) (*entities.AuthResult, error) {
	// SECRET_HASH = Base64(HMAC-SHA256(clientSecret, username + clientID)).
	mac := hmac.New(sha256.New, []byte(c.poolClientSecret))
	mac.Write([]byte(username + c.poolClientID))
	secretHash := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	cognitoClient := cognitoidentityprovider.New(c.session)
	res, err := cognitoClient.InitiateAuthWithContext(ctx, &cognitoidentityprovider.InitiateAuthInput{
		AuthFlow: aws.String("USER_PASSWORD_AUTH"),
		AuthParameters: map[string]*string{
			"USERNAME":    aws.String(username),
			"PASSWORD":    aws.String(password),
			"SECRET_HASH": aws.String(secretHash),
		},
		ClientId: aws.String(c.poolClientID),
	})
	if err != nil {
		return nil, err
	}

	if res.ChallengeName != nil {
		switch *res.ChallengeName {
		case "NEW_PASSWORD_REQUIRED":
			return &entities.AuthResult{
				// StringValue avoids a nil dereference if Session is absent.
				Session:                aws.StringValue(res.Session),
				PasswordChangeRequired: true,
			}, nil
		default:
			return nil, models.NotImplementedError{
				Function: "cognitoauth.Client",
				Message:  "Missing handler for challenge: " + *res.ChallengeName,
			}
		}
	}

	if res.AuthenticationResult == nil || res.AuthenticationResult.IdToken == nil {
		return nil, models.PermissionDeniedError{}
	}

	idToken := *res.AuthenticationResult.IdToken
	user := c.ValidateToken(ctx, idToken)
	if user == nil {
		// Never report a successful login without a verified identity.
		return nil, models.PermissionDeniedError{}
	}

	return &entities.AuthResult{
		User:         user,
		Token:        idToken,
		RefreshToken: aws.StringValue(res.AuthenticationResult.RefreshToken),
	}, nil
}
// SetupUser answers a NEW_PASSWORD_REQUIRED challenge by setting newPassword
// for username. session is the challenge session from a LoginUser result
// whose PasswordChangeRequired was set. Inside this function the parameter
// shadows the imported session package.
//
// It returns the user identified by the resulting ID token. It returns
// models.PermissionDeniedError if Cognito issues no ID token (e.g. because
// another challenge follows) or if the token fails validation. Errors from
// the Cognito API are returned unchanged.
func (c *Client) SetupUser(ctx context.Context, session, username, newPassword string) (*entities.User, error) {
	// SECRET_HASH = Base64(HMAC-SHA256(clientSecret, username + clientID)).
	mac := hmac.New(sha256.New, []byte(c.poolClientSecret))
	mac.Write([]byte(username + c.poolClientID))
	secretHash := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	cognitoClient := cognitoidentityprovider.New(c.session)
	res, err := cognitoClient.RespondToAuthChallengeWithContext(ctx, &cognitoidentityprovider.RespondToAuthChallengeInput{
		ChallengeName: aws.String("NEW_PASSWORD_REQUIRED"),
		ChallengeResponses: map[string]*string{
			"NEW_PASSWORD": aws.String(newPassword),
			"USERNAME":     aws.String(username),
			"SECRET_HASH":  aws.String(secretHash),
		},
		Session:  aws.String(session),
		ClientId: aws.String(c.poolClientID),
	})
	if err != nil {
		return nil, err
	}
	if res.AuthenticationResult == nil || res.AuthenticationResult.IdToken == nil {
		return nil, models.PermissionDeniedError{}
	}

	user := c.ValidateToken(ctx, *res.AuthenticationResult.IdToken)
	if user == nil {
		// Previously this returned (nil, nil), which callers could mistake
		// for success.
		return nil, models.PermissionDeniedError{}
	}
	return user, nil
}
// parseToken decodes the payload of a JWT WITHOUT verifying its signature
// and extracts the user it names. It returns nil if token does not have three
// dot-separated segments, or if its payload is not base64url-encoded JSON.
//
// The user ID is the "custom:actual_userid" claim when that claim is a
// string, and the "sub" claim otherwise. Name is the "cognito:username" claim.
// Callers that need a trustworthy identity must use ValidateToken instead.
func (c *Client) parseToken(token string) *entities.User {
	parts := strings.SplitN(token, ".", 3)
	if len(parts) != 3 {
		return nil
	}

	// JWT segments use the URL-safe base64 alphabet without padding
	// (RFC 7515). The standard alphabet rejects '-' and '_', which valid
	// payloads can contain. Trailing '=' is tolerated for lenient producers.
	payload, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "="))
	if err != nil {
		return nil
	}

	claims := make(map[string]interface{}, 16)
	if err := json.Unmarshal(payload, &claims); err != nil {
		return nil
	}

	userID, _ := claims["sub"].(string)
	if actual, ok := claims["custom:actual_userid"].(string); ok {
		userID = actual
	}
	userName, _ := claims["cognito:username"].(string)

	return &entities.User{
		ID:   userID,
		Name: userName,
	}
}
// ValidateToken verifies a Cognito ID token and returns the user it
// identifies, or nil if the token is not acceptable.
//
// A token is accepted only if all of the following hold:
//   - it is signed with RS256;
//   - its signing key, selected by the "kid" header, is in the pool's key set;
//   - it passes the validation jwt.Parse performs;
//   - its token_use claim is "id".
//
// Audience validation is deliberately disabled (jwt.WithoutAudienceValidation).
// The context is currently unused.
func (c *Client) ValidateToken(_ context.Context, token string) *entities.User {
	parsed, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
		if t.Method.Alg() != jwa.RS256.String() {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}

		kid, ok := t.Header["kid"].(string)
		if !ok {
			return nil, errors.New("kid header not found")
		}

		key, ok := c.keySet.LookupKeyID(kid)
		if !ok {
			return nil, fmt.Errorf("key %v not found", kid)
		}

		var raw interface{}
		if err := key.Raw(&raw); err != nil {
			return nil, err
		}
		return raw, nil
	}, jwt.WithoutAudienceValidation())
	if err != nil {
		return nil
	}

	// Cognito signs ID and access tokens with the same keys. Only ID tokens
	// carry the identity claims parseToken reads (e.g. cognito:username), so
	// reject anything else.
	claims, ok := parsed.Claims.(jwt.MapClaims)
	if !ok || claims["token_use"] != "id" {
		return nil
	}

	return c.parseToken(token)
}
// New creates a Client for the Cognito user pool poolId in region regionId.
//
// clientID and clientSecret are the static AWS credentials used for Cognito
// API calls. poolClientId and poolClientSecret identify the pool's app client.
//
// New downloads the pool's JSON Web Key set so ValidateToken can check
// signatures later. That download is bounded by a timeout, so a stalled
// endpoint cannot block construction forever.
func New(regionId, clientID, clientSecret, poolId, poolClientId, poolClientSecret string) (*Client, error) {
	const jwksFetchTimeout = 30 * time.Second

	s, err := session.NewSession(&aws.Config{
		Region:      aws.String(regionId),
		Credentials: credentials.NewStaticCredentials(clientID, clientSecret, ""),
	})
	if err != nil {
		return nil, fmt.Errorf("creating AWS session: %w", err)
	}

	jwksURL := fmt.Sprintf(
		"https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json",
		regionId,
		poolId,
	)
	ctx, cancel := context.WithTimeout(context.Background(), jwksFetchTimeout)
	defer cancel()

	keySet, err := jwk.Fetch(ctx, jwksURL)
	if err != nil {
		return nil, fmt.Errorf("fetching JWKS from %s: %w", jwksURL, err)
	}

	return &Client{
		poolID:           poolId,
		poolClientID:     poolClientId,
		poolClientSecret: poolClientSecret,
		session:          s,
		keySet:           keySet,
	}, nil
}