app/auth/helper.go

202 lines
4.3 KiB
Go
Raw Permalink Normal View History

2021-12-16 04:11:33 +00:00
package auth
import (
"github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// HasGroup reports whether the claims contain at least one of the given
// group names. It returns false when no names are passed or when the claims
// carry no groups.
func (c *Claims) HasGroup(grps ...string) bool {
	wanted := make(map[string]struct{}, len(grps))
	for _, name := range grps {
		wanted[name] = struct{}{}
	}
	for _, have := range c.Groups {
		if _, ok := wanted[have]; ok {
			return true
		}
	}
	return false
}
// SetDefaultGroups makes sure the built-in SYSTEM, ADMIN and USER groups
// exist, creating any that are missing. Groups that already exist are left
// untouched.
//
// Existence is checked with Limit(1).Find and RowsAffected rather than First,
// so that only a genuinely absent group triggers a create; any other lookup
// error (e.g. a broken connection) is logged and returned instead of being
// mistaken for "not found".
func (srv Service) SetDefaultGroups() error {
	srv.Logger.Debug().Msg("Setup default groups")
	for _, name := range []string{SYSTEM, ADMIN, USER} {
		lookup := DB.Where("name = ?", name).Limit(1).Find(&Group{})
		if lookup.Error != nil {
			srv.Logger.Error().Err(lookup.Error).Msg("lookup group error")
			return lookup.Error
		}
		if lookup.RowsAffected > 0 {
			// Group already present.
			continue
		}
		if err := DB.Create(&Group{Name: name}).Error; err != nil {
			srv.Logger.Error().Err(err).Msg("create group error")
			return err
		}
	}
	return nil
}
// SetDefaultAdmin bootstraps an administrator account when the ADMIN group has
// no members yet; if it already has at least one member, nothing is changed.
//
// In a single transaction it creates a user with the given username and a
// bcrypt-hashed (cost 14) password, a personal group named after the user,
// membership in ADMIN and that personal group, and a Profile whose
// DisplayName is the username. Afterwards the user is added to the USER group
// via SetGroup (outside the transaction).
//
// The ADMIN group must already exist; a failed lookup is logged and returned.
func (srv Service) SetDefaultAdmin(username, password string) error {
	admin := &Group{}
	if err := DB.Where("name = ?", ADMIN).First(admin).Error; err != nil {
		srv.Logger.Error().Err(err).Msg("SetDefaultAdmin")
		return err
	}
	// Only the number of existing ADMIN memberships matters (RowsAffected);
	// the scanned row itself is not used.
	usrgrp := struct {
		GroupID uint
		UserID  string
	}{}
	result := DB.
		Raw("select * from user_groups where group_id = ?", admin.ID).
		Scan(&usrgrp)
	if result.Error != nil {
		srv.Logger.Error().Err(result.Error).Msg("SetDefaultAdmin")
		return result.Error
	}
	if result.RowsAffected != 0 {
		// The ADMIN group already has a member; nothing to bootstrap.
		return nil
	}
	srv.Logger.Debug().Msg("Setting up admin account")
	// GenerateFromPassword can fail (invalid cost and, in newer x/crypto
	// releases, passwords longer than 72 bytes). Dropping that error would
	// persist an admin whose password hash is empty.
	pwd, err := bcrypt.GenerateFromPassword([]byte(password), 14)
	if err != nil {
		srv.Logger.Error().Err(err).Msg("SetDefaultAdmin")
		return err
	}
	usr := &User{}
	usr.Username = username
	usr.Password = string(pwd)
	err = DB.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(usr).Error; err != nil {
			return err
		}
		grp := &Group{Name: username}
		if err := tx.Create(grp).Error; err != nil {
			return err
		}
		if err := tx.Model(usr).Association("Groups").Append(admin, grp); err != nil {
			return err
		}
		profile := &Profile{DisplayName: username}
		return tx.Model(usr).Association("Profile").Append(profile)
	})
	if err != nil {
		return err
	}
	// SetGroup returns nothing, so a failure here is not reported.
	SetGroup(usr.ID, USER, true)
	return nil
}
// IsLastAdmin reports whether the ADMIN group has exactly one member. The
// activation state of that member is not considered.
//
// If the count query fails, the membership count is unknown and the function
// returns true. NOTE(review): this assumes callers use it to refuse removing
// the last administrator, so erring towards "last admin" fails safe — confirm
// against callers.
func IsLastAdmin() bool {
	var count int
	err := DB.Raw(`
select count(*) from user_groups where group_id = (
select id from groups where name = ?
)
`, ADMIN).Scan(&count).Error
	if err != nil {
		return true
	}
	return count == 1
}
// IsLastAdminUser reports whether no activated user other than uid
// (u.activated = 1) belongs to the ADMIN group. It does not check whether uid
// itself is an admin. Errors from the query are not inspected: the count is
// then left at 0 and the function reports true.
func IsLastAdminUser(uid string) bool {
	const query = `
select count(*) from user_groups ug, users u
where group_id = (select id from groups where name = ?)
and user_id <> ?
and u.id = user_id
and u.activated = 1
`
	others := 0
	DB.Raw(query, ADMIN, uid).Scan(&others)
	return others == 0
}
func SetAdmin(uid string, set bool) {
SetGroup(uid, ADMIN, set)
}
// SetGroup adds (set == true) or removes (set == false) the membership of user
// uid in the group named group_name by editing the user_groups join table
// directly.
//
// The call is best-effort and returns nothing; database errors are not
// reported to the caller. When the group cannot be resolved, or the
// membership count cannot be read, the call does nothing. Removal deletes
// every matching row, so duplicated memberships are cleaned up too.
func SetGroup(uid string, group_name string, set bool) {
	var groupID uint
	lookup := DB.Raw(
		`select id from groups where name = ?`,
		group_name,
	).Scan(&groupID)
	// An unknown name leaves groupID at 0; inserting with it would create a
	// dangling membership. Assumes auto-increment ids start at 1.
	if lookup.Error != nil || groupID == 0 {
		return
	}
	var count int
	if err := DB.Raw(`
select count(*) from user_groups where group_id = ?
and user_id = ?`,
		groupID, uid).Scan(&count).Error; err != nil {
		return
	}
	switch {
	case !set && count > 0:
		DB.Exec(`delete from user_groups where group_id = ?
and user_id = ?`, groupID, uid)
	case set && count == 0:
		DB.Exec(`insert into user_groups (group_id, user_id)
values (?, ?)`, groupID, uid)
	}
}
// GetUser loads the user whose id equals the Uid of the claims stored in the
// request context (see GetContextClaims), with its Groups and Profile
// associations preloaded. Any error from reading the claims or from the query
// is returned unchanged with a nil user.
func GetUser(c *gin.Context, db *gorm.DB) (*User, error) {
	claims, err := GetContextClaims(c)
	if err != nil {
		return nil, err
	}
	var usr User
	query := db.Preload("Groups").Preload("Profile").Where("id = ?", claims.Uid)
	if err := query.First(&usr).Error; err != nil {
		return nil, err
	}
	return &usr, nil
}
// NewUser persists a new user. usr.Password must hold the plaintext password;
// it is replaced with its bcrypt hash (cost 14) before the user is saved.
//
// In one transaction on db it creates the user, a personal group named after
// usr.Username, membership in that group and in the USER group, and a Profile
// whose DisplayName is the username. Any failure rolls the transaction back
// and is returned.
//
// It returns ErrorBadRequestTmpl.New("auth.User") when Username or Password is
// empty. A nil db falls back to the package-level DB.
func NewUser(usr *User, db *gorm.DB) error {
	if usr.Username == "" || usr.Password == "" {
		return ErrorBadRequestTmpl.New("auth.User")
	}
	if db == nil {
		db = DB
	}
	hash, err := bcrypt.GenerateFromPassword([]byte(usr.Password), 14)
	if err != nil {
		return err
	}
	usr.Password = string(hash)
	return db.Transaction(func(tx *gorm.DB) error {
		// Resolve the USER group first and through tx: a missing group aborts
		// before anything is written, and the lookup does not need a second
		// pool connection while the transaction holds one (which could block
		// with a single-connection pool, e.g. SQLite).
		userGroup := &Group{}
		if err := tx.Where("name = ?", USER).First(userGroup).Error; err != nil {
			return err
		}
		if err := tx.Create(usr).Error; err != nil {
			return err
		}
		personal := &Group{Name: usr.Username}
		if err := tx.Create(personal).Error; err != nil {
			return err
		}
		if err := tx.Model(usr).Association("Groups").Append(personal, userGroup); err != nil {
			return err
		}
		profile := &Profile{DisplayName: usr.Username}
		return tx.Model(usr).Association("Profile").Append(profile)
	})
}