197 lines
4.5 KiB
Go
197 lines
4.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt"
|
|
"github.com/rs/xid"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/spf13/viper"
|
|
"kumoly.io/kumoly/app/errors"
|
|
"kumoly.io/kumoly/app/system"
|
|
)
|
|
|
|
// Keys under which Middleware stores request-scoped auth data in the
// gin.Context.
const (
	// GinClaimKey holds the parsed *Claims of the current request.
	GinClaimKey = "claim"
	// GinUserKey holds the User field of the current request's claims.
	GinUserKey = "user"
)
|
|
|
|
type Claims struct {
|
|
Uid string `json:"uid,omitempty"`
|
|
User string `json:"usr,omitempty"`
|
|
Groups []string `json:"grp,omitempty"`
|
|
Endpoint string `json:"ept,omitempty"`
|
|
IP string `json:"ip,omitempty"`
|
|
jwt.StandardClaims
|
|
}
|
|
|
|
type Auth struct {
|
|
system.BaseService
|
|
CookieMode bool
|
|
CookieSecure bool
|
|
CookieSameSite http.SameSite
|
|
TokenExpire int64
|
|
Endpoint string
|
|
AutoRenew bool
|
|
Secret string
|
|
}
|
|
|
|
func NewAuth() *Auth {
|
|
return &Auth{
|
|
CookieMode: true,
|
|
CookieSecure: strings.HasPrefix(viper.GetString("server.url"), "https"),
|
|
CookieSameSite: http.SameSiteLaxMode,
|
|
TokenExpire: viper.GetInt64("auth.expire"),
|
|
AutoRenew: true,
|
|
Secret: viper.GetString("auth.secret"),
|
|
}
|
|
}
|
|
|
|
// Parse tok str to token object
|
|
func (srv Auth) Parse(tok string) (token *jwt.Token, err error) {
|
|
token, err = jwt.ParseWithClaims(tok, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(srv.Secret), nil
|
|
})
|
|
return
|
|
}
|
|
|
|
// ParseClaims parse token string to claims object
|
|
func (srv Auth) ParseClaims(tok string) (claims *Claims, err error) {
|
|
token, err := srv.Parse(tok)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
claims, ok := token.Claims.(*Claims)
|
|
if !ok {
|
|
err = errors.New(http.StatusBadRequest, "ErrorUnknownClaims")
|
|
}
|
|
return
|
|
}
|
|
|
|
// SetToken in header and cookie(if CookieMode)
|
|
func (srv Auth) SetToken(c *gin.Context, tok string) {
|
|
c.Header("Authorization", "Bearer "+tok)
|
|
if srv.CookieMode {
|
|
http.SetCookie(c.Writer, &http.Cookie{
|
|
Name: viper.GetString("name") + "_bearer",
|
|
MaxAge: int(srv.TokenExpire),
|
|
Value: tok,
|
|
SameSite: srv.CookieSameSite,
|
|
Secure: srv.CookieSecure,
|
|
Path: "/",
|
|
})
|
|
}
|
|
}
|
|
|
|
// SetClaims directly to response
|
|
func (srv Auth) SetClaims(c *gin.Context, claims *Claims) error {
|
|
tok, err := srv.NewToken(*claims)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
srv.SetToken(c, tok)
|
|
return nil
|
|
}
|
|
|
|
// GetToken from header or cookie
|
|
func (srv Auth) GetToken(c *gin.Context) (tok string, err error) {
|
|
tok = strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")
|
|
if tok == "" && srv.CookieMode {
|
|
tok, err = c.Cookie(viper.GetString("name") + "_bearer")
|
|
}
|
|
if err != nil {
|
|
err = nil
|
|
return
|
|
}
|
|
if tok == "" {
|
|
err = errors.New(401, "ErrorTokenNotFound")
|
|
return
|
|
}
|
|
return tok, nil
|
|
}
|
|
|
|
// GetClaims directly from http request
|
|
func (srv Auth) GetClaims(c *gin.Context) (claims *Claims, err error) {
|
|
tok, err := srv.GetToken(c)
|
|
if err != nil {
|
|
return
|
|
}
|
|
claims, err = srv.ParseClaims(tok)
|
|
return
|
|
}
|
|
|
|
func (srv Auth) ClearToken(c *gin.Context) {
|
|
c.Writer.Header().Del("Authorization")
|
|
if srv.CookieMode {
|
|
http.SetCookie(c.Writer, &http.Cookie{
|
|
Name: viper.GetString("name") + "_bearer",
|
|
MaxAge: -1,
|
|
Value: "",
|
|
SameSite: srv.CookieSameSite,
|
|
Secure: srv.CookieSecure,
|
|
Path: "/",
|
|
})
|
|
}
|
|
}
|
|
|
|
// New tok str
|
|
func (srv Auth) NewToken(claims Claims) (tok string, err error) {
|
|
if srv.TokenExpire > 0 && claims.ExpiresAt == 0 {
|
|
claims.ExpiresAt = time.Now().Unix() + srv.TokenExpire
|
|
} else if claims.ExpiresAt < 0 {
|
|
claims.ExpiresAt = 0
|
|
}
|
|
claims.Issuer = viper.GetString("name")
|
|
claims.Id = xid.New().String()
|
|
claims.Endpoint = viper.GetString("domain")
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tok, err = token.SignedString([]byte(srv.Secret))
|
|
if err != nil {
|
|
log.Error().Str("mod", "auth").Err(err).Msg("NewToken")
|
|
}
|
|
return
|
|
}
|
|
|
|
func (srv Auth) Middleware(c *gin.Context) {
|
|
claims, err := srv.GetClaims(c)
|
|
if err == nil {
|
|
c.Set(GinClaimKey, claims)
|
|
c.Set(GinUserKey, claims.User)
|
|
if srv.AutoRenew && claims.ExpiresAt != 0 {
|
|
claims.ExpiresAt = time.Now().Unix() + srv.TokenExpire
|
|
tok, err := srv.NewToken(*claims)
|
|
if err != nil {
|
|
log.Error().Str("mod", "auth").Err(err).Msg("Middleware")
|
|
} else {
|
|
srv.SetToken(c, tok)
|
|
}
|
|
}
|
|
} else {
|
|
log.Trace().Err(err).Msg("")
|
|
}
|
|
}
|
|
|
|
func (srv Auth) Injector(router *gin.RouterGroup) *system.Inject {
|
|
return &system.Inject{
|
|
Name: "auth.Auth",
|
|
InitFunc: func() error {
|
|
router.Use(srv.Middleware)
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func GetContextClaims(c *gin.Context) (claims *Claims, err error) {
|
|
cl, ok := c.Get(GinClaimKey)
|
|
if !ok {
|
|
err = ErrorTokenNotValid
|
|
return
|
|
}
|
|
claims, ok = cl.(*Claims)
|
|
if !ok {
|
|
err = ErrorUnknownClaims
|
|
}
|
|
return
|
|
}
|