package auth import ( "net/http" "strings" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" "github.com/rs/xid" "github.com/rs/zerolog/log" "github.com/spf13/viper" "kumoly.io/kumoly/app/errors" "kumoly.io/kumoly/app/system" ) const GinClaimKey = "claim" type Claims struct { Uid string `json:"uid,omitempty"` Groups []string `json:"grp,omitempty"` Endpoint string `json:"ept,omitempty"` IP string `json:"ip,omitempty"` jwt.StandardClaims } type Auth struct { system.BaseService CookieMode bool CookieSecure bool CookieSameSite http.SameSite TokenExpire int64 Endpoint string AutoRenew bool Secret string } func NewAuth() *Auth { return &Auth{ CookieMode: true, CookieSecure: strings.HasPrefix(viper.GetString("server.url"), "https"), CookieSameSite: http.SameSiteLaxMode, TokenExpire: viper.GetInt64("auth.expire"), AutoRenew: true, Secret: viper.GetString("auth.secret"), } } // Parse tok str to token object func (srv Auth) Parse(tok string) (token *jwt.Token, err error) { token, err = jwt.ParseWithClaims(tok, &Claims{}, func(token *jwt.Token) (interface{}, error) { return []byte(srv.Secret), nil }) return } // ParseClaims parse token string to claims object func (srv Auth) ParseClaims(tok string) (claims *Claims, err error) { token, err := srv.Parse(tok) if err != nil { return nil, err } claims, ok := token.Claims.(*Claims) if !ok { err = errors.New(http.StatusBadRequest, "ErrorUnknownClaims") } return } // SetToken in header and cookie(if CookieMode) func (srv Auth) SetToken(c *gin.Context, tok string) { c.Header("Authorization", "Bearer "+tok) if srv.CookieMode { http.SetCookie(c.Writer, &http.Cookie{ Name: viper.GetString("name") + "_bearer", MaxAge: int(srv.TokenExpire), Value: tok, SameSite: srv.CookieSameSite, Secure: srv.CookieSecure, Path: "/", }) } } // SetClaims directly to response func (srv Auth) SetClaims(c *gin.Context, claims *Claims) error { tok, err := srv.NewToken(*claims) if err != nil { return err } srv.SetToken(c, tok) return nil } // GetToken from header or cookie func (srv Auth) GetToken(c *gin.Context) (tok string, err error) { tok = strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ") if tok == "" && srv.CookieMode { tok, err = c.Cookie(viper.GetString("name") + "_bearer") } if err != nil { err = nil return } if tok == "" { err = errors.New(401, "ErrorTokenNotFound") return } return tok, nil } // GetClaims directly from http request func (srv Auth) GetClaims(c *gin.Context) (claims *Claims, err error) { tok, err := srv.GetToken(c) if err != nil { return } claims, err = srv.ParseClaims(tok) return } func (srv Auth) ClearToken(c *gin.Context) { c.Writer.Header().Del("Authorization") if srv.CookieMode { http.SetCookie(c.Writer, &http.Cookie{ Name: viper.GetString("name") + "_bearer", MaxAge: -1, Value: "", SameSite: srv.CookieSameSite, Secure: srv.CookieSecure, Path: "/", }) } } // New tok str func (srv Auth) NewToken(claims Claims) (tok string, err error) { if srv.TokenExpire > 0 && claims.ExpiresAt == 0 { claims.ExpiresAt = time.Now().Unix() + srv.TokenExpire } else if claims.ExpiresAt < 0 { claims.ExpiresAt = 0 } claims.Issuer = viper.GetString("name") claims.Id = xid.New().String() claims.Endpoint = viper.GetString("domain") token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tok, err = token.SignedString([]byte(srv.Secret)) if err != nil { log.Error().Str("mod", "auth").Err(err).Msg("NewToken") } return } func (srv Auth) Middleware(c *gin.Context) { claims, err := srv.GetClaims(c) if err == nil { c.Set(GinClaimKey, claims) if srv.AutoRenew && claims.ExpiresAt != 0 { claims.ExpiresAt = time.Now().Unix() + srv.TokenExpire tok, err := srv.NewToken(*claims) if err != nil { log.Error().Str("mod", "auth").Err(err).Msg("Middleware") } else { srv.SetToken(c, tok) } } } else { log.Trace().Err(err).Msg("") } } func (srv Auth) Injector(router *gin.RouterGroup) *system.Inject { return &system.Inject{ Name: "auth.Auth", InitFunc: func() error { router.Use(srv.Middleware) return nil }, } } func GetContextClaims(c *gin.Context) (claims *Claims, err error) { cl, ok := c.Get(GinClaimKey) if !ok { err = ErrorTokenNotValid return } claims, ok = cl.(*Claims) if !ok { err = ErrorUnknownClaims } return }