app/auth/jwt.go

194 lines
4.4 KiB
Go
Raw Permalink Normal View History

2021-12-16 04:11:33 +00:00
package auth
import (
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"github.com/rs/xid"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"kumoly.io/kumoly/app/errors"
"kumoly.io/kumoly/app/system"
)
// GinClaimKey is the gin context key under which Middleware stores the
// parsed *Claims and from which GetContextClaims reads them back.
const GinClaimKey = "claim"
// Claims is the JWT payload handled by this package: a few application
// fields plus the embedded jwt.StandardClaims.
type Claims struct {
// Uid is serialized as "uid"; presumably the user identifier — confirm with callers.
Uid string `json:"uid,omitempty"`
// Groups is serialized as "grp"; presumably the groups the user belongs to — confirm with callers.
Groups []string `json:"grp,omitempty"`
// Endpoint is serialized as "ept"; NewToken overwrites it with the "domain" config value.
Endpoint string `json:"ept,omitempty"`
// IP is serialized as "ip"; nothing in this file sets or checks it.
IP string `json:"ip,omitempty"`
// StandardClaims supplies the fields this file uses: ExpiresAt, Issuer and Id.
jwt.StandardClaims
}
// Auth issues, transports and validates the JWTs used for authentication.
// NewAuth builds an instance from the viper configuration.
type Auth struct {
system.BaseService
// CookieMode makes SetToken/ClearToken also write the "<name>_bearer" cookie
// and lets GetToken fall back to it when no Authorization header is present.
CookieMode bool
// CookieSecure is copied to the Secure attribute of the cookie.
CookieSecure bool
// CookieSameSite is copied to the SameSite attribute of the cookie.
CookieSameSite http.SameSite
// TokenExpire is the token lifetime in seconds: it is added to the current
// Unix time to compute ExpiresAt and is used as the cookie MaxAge.
TokenExpire int64
// Endpoint is not read anywhere in this file.
// NOTE(review): NewToken takes the endpoint from the "domain" config instead — confirm intent.
Endpoint string
// AutoRenew makes Middleware re-issue a token with a fresh expiry on each
// authenticated request whose token carries an expiry.
AutoRenew bool
// Secret is the key used to sign (HS256) and verify tokens.
Secret string
}
// NewAuth returns an Auth populated from the viper configuration.
//
// Cookie mode and auto-renew are enabled, SameSite is Lax, and the cookie is
// marked Secure only when "server.url" starts with "https". The token
// lifetime comes from "auth.expire" and the signing key from "auth.secret".
func NewAuth() *Auth {
	auth := new(Auth)
	auth.CookieMode = true
	auth.CookieSecure = strings.HasPrefix(viper.GetString("server.url"), "https")
	auth.CookieSameSite = http.SameSiteLaxMode
	auth.TokenExpire = viper.GetInt64("auth.expire")
	auth.AutoRenew = true
	auth.Secret = viper.GetString("auth.secret")
	return auth
}
// Parse verifies the signed token string tok with srv.Secret and decodes its
// payload into a *Claims (available as token.Claims).
//
// Only HMAC-signed tokens are accepted: the key function refuses to hand out
// the secret for any other signing method, so a token announcing a different
// algorithm is rejected before signature verification. The returned error is
// whatever jwt.ParseWithClaims reports (malformed, bad signature, expired...).
func (srv Auth) Parse(tok string) (token *jwt.Token, err error) {
	keyFunc := func(t *jwt.Token) (interface{}, error) {
		// NewToken only ever signs with HS256; never trust the header's
		// "alg" to pick a verification scheme for us.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, jwt.NewValidationError("unexpected signing method", jwt.ValidationErrorSignatureInvalid)
		}
		return []byte(srv.Secret), nil
	}
	return jwt.ParseWithClaims(tok, &Claims{}, keyFunc)
}
// ParseClaims parses the token string tok and returns its payload as *Claims.
//
// Any error from Parse is returned unchanged. If the parsed token does not
// carry a *Claims payload, a 400 "ErrorUnknownClaims" error is returned.
func (srv Auth) ParseClaims(tok string) (claims *Claims, err error) {
	parsed, parseErr := srv.Parse(tok)
	if parseErr != nil {
		return nil, parseErr
	}
	if typed, matches := parsed.Claims.(*Claims); matches {
		return typed, nil
	}
	err = errors.New(http.StatusBadRequest, "ErrorUnknownClaims")
	return nil, err
}
// SetToken writes tok to the response: always as an "Authorization: Bearer"
// header, and additionally as the "<name>_bearer" cookie when CookieMode is on.
// The cookie's MaxAge is the configured token lifetime.
func (srv Auth) SetToken(c *gin.Context, tok string) {
	c.Header("Authorization", "Bearer "+tok)
	if !srv.CookieMode {
		return
	}
	bearer := http.Cookie{
		Name:     viper.GetString("name") + "_bearer",
		Value:    tok,
		Path:     "/",
		MaxAge:   int(srv.TokenExpire),
		Secure:   srv.CookieSecure,
		SameSite: srv.CookieSameSite,
	}
	http.SetCookie(c.Writer, &bearer)
}
// SetClaims signs claims into a new token and writes it to the response via
// SetToken. If signing fails, the error is returned and nothing is written.
func (srv Auth) SetClaims(c *gin.Context, claims *Claims) error {
	signed, err := srv.NewToken(*claims)
	if err == nil {
		srv.SetToken(c, signed)
	}
	return err
}
// GetToken extracts the raw token string from the request.
//
// The Authorization header is consulted first, with a leading "Bearer "
// stripped (a value without that prefix is used verbatim). If the header
// yields nothing and CookieMode is on, the "<name>_bearer" cookie is used.
//
// It returns a 401 "ErrorTokenNotFound" error when neither source provides a
// token. A missing cookie is treated as "no token" rather than being returned
// as ("", nil), so callers can rely on a non-empty token whenever err is nil.
func (srv Auth) GetToken(c *gin.Context) (tok string, err error) {
	tok = strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")
	if tok == "" && srv.CookieMode {
		// The only thing a cookie lookup failure tells us is that there is
		// no cookie; fall through to the not-found check below.
		if fromCookie, cookieErr := c.Cookie(viper.GetString("name") + "_bearer"); cookieErr == nil {
			tok = fromCookie
		}
	}
	if tok == "" {
		err = errors.New(401, "ErrorTokenNotFound")
		return "", err
	}
	return tok, nil
}
// GetClaims reads the token from the request (see GetToken) and parses it
// into *Claims (see ParseClaims). The first error encountered is returned
// together with nil claims.
func (srv Auth) GetClaims(c *gin.Context) (claims *Claims, err error) {
	raw, tokErr := srv.GetToken(c)
	if tokErr != nil {
		return nil, tokErr
	}
	return srv.ParseClaims(raw)
}
// ClearToken removes the token from the response: it drops any Authorization
// response header and, when CookieMode is on, overwrites the "<name>_bearer"
// cookie with an empty value and MaxAge -1.
func (srv Auth) ClearToken(c *gin.Context) {
	c.Writer.Header().Del("Authorization")
	if !srv.CookieMode {
		return
	}
	expired := http.Cookie{
		Name:     viper.GetString("name") + "_bearer",
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		Secure:   srv.CookieSecure,
		SameSite: srv.CookieSameSite,
	}
	http.SetCookie(c.Writer, &expired)
}
// NewToken signs claims with HS256 using srv.Secret and returns the compact
// token string.
//
// Before signing, the (by-value) claims are adjusted:
//   - ExpiresAt becomes now + TokenExpire when a positive lifetime is
//     configured and the claims carry no expiry; a negative ExpiresAt is
//     reset to 0;
//   - Issuer is set to the "name" config value, Id to a fresh xid, and
//     Endpoint to the "domain" config value.
//
// A signing failure is logged and returned.
func (srv Auth) NewToken(claims Claims) (tok string, err error) {
	switch {
	case srv.TokenExpire > 0 && claims.ExpiresAt == 0:
		claims.ExpiresAt = time.Now().Unix() + srv.TokenExpire
	case claims.ExpiresAt < 0:
		claims.ExpiresAt = 0
	}

	claims.Issuer = viper.GetString("name")
	claims.Id = xid.New().String()
	claims.Endpoint = viper.GetString("domain")

	signed, signErr := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(srv.Secret))
	if signErr != nil {
		log.Error().Str("mod", "auth").Err(signErr).Msg("NewToken")
	}
	return signed, signErr
}
// Middleware is a gin handler that authenticates the request on a best-effort
// basis. It never aborts the request.
//
// When the request carries a valid token, its claims are stored in the gin
// context under GinClaimKey. When it does not, the error is only logged at
// trace level and no claims are stored.
//
// With AutoRenew on, a token that has an expiry is re-issued with a fresh
// expiry of now + TokenExpire and written back to the response. Renewal is
// skipped when no positive TokenExpire is configured: otherwise the
// replacement token would expire immediately and end the session.
func (srv Auth) Middleware(c *gin.Context) {
	claims, err := srv.GetClaims(c)
	if err != nil {
		log.Trace().Err(err).Msg("")
		return
	}
	c.Set(GinClaimKey, claims)

	if !srv.AutoRenew || claims.ExpiresAt == 0 || srv.TokenExpire <= 0 {
		return
	}
	// The expiry is pushed forward on the claims stored in the context as
	// well, as before; NewToken works on its own copy.
	claims.ExpiresAt = time.Now().Unix() + srv.TokenExpire
	tok, err := srv.NewToken(*claims)
	if err != nil {
		log.Error().Str("mod", "auth").Err(err).Msg("Middleware")
		return
	}
	srv.SetToken(c, tok)
}
// Injector returns a system.Inject named "auth.Auth" whose InitFunc installs
// srv.Middleware on router.
func (srv Auth) Injector(router *gin.RouterGroup) *system.Inject {
	install := func() error {
		router.Use(srv.Middleware)
		return nil
	}
	return &system.Inject{
		Name:     "auth.Auth",
		InitFunc: install,
	}
}
// GetContextClaims returns the claims stored in the gin context under
// GinClaimKey.
//
// It returns ErrorTokenNotValid when nothing is stored under that key, and
// ErrorUnknownClaims when the stored value is not a *Claims.
func GetContextClaims(c *gin.Context) (claims *Claims, err error) {
	stored, found := c.Get(GinClaimKey)
	if !found {
		err = ErrorTokenNotValid
		return nil, err
	}
	typed, isClaims := stored.(*Claims)
	if !isClaims {
		err = ErrorUnknownClaims
		return nil, err
	}
	return typed, nil
}