feat: add default skip static
parent
e27c6e7f45
commit
cd95cf5454
7
guard.go
7
guard.go
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"kumoly.io/lib/guard/netutil"
|
||||
)
|
||||
|
||||
type Guard struct {
|
||||
|
@ -67,7 +68,7 @@ func (g *Guard) Guard(next http.Handler) http.Handler {
|
|||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
rw := &responseWriter{w, 0, ""}
|
||||
ip := GetIP(r)
|
||||
ip := netutil.GetIP(r)
|
||||
defer func() {
|
||||
var cl *zerolog.Event
|
||||
err := recover()
|
||||
|
@ -76,7 +77,7 @@ func (g *Guard) Guard(next http.Handler) http.Handler {
|
|||
switch v := err.(type) {
|
||||
case Error:
|
||||
rw.WriteHeader(v.Code)
|
||||
JSON(rw, v)
|
||||
netutil.JSON(rw, v)
|
||||
cl.Err(v)
|
||||
case error:
|
||||
rw.WriteHeader(500)
|
||||
|
@ -108,7 +109,7 @@ func (g *Guard) Guard(next http.Handler) http.Handler {
|
|||
if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) {
|
||||
panic(ErrorForbidden)
|
||||
}
|
||||
if g.AllowIPGlob != "" && MatchIPGlob(ip, g.AllowIPGlob) {
|
||||
if g.AllowIPGlob != "" && netutil.MatchIPGlob(ip, g.AllowIPGlob) {
|
||||
panic(ErrorForbidden)
|
||||
}
|
||||
if g.basicAuth != "" {
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
package netutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// JSON shorthand for json response
|
||||
func JSON(w http.ResponseWriter, value interface{}) (int, error) {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return w.Write(data)
|
||||
}
|
||||
|
||||
// MatchIPGlob match ip to glob pattern, ex. * 192.168.* 192.* 192.168.51.*2*
|
||||
func MatchIPGlob(ip, pattern string) bool {
|
||||
parts := strings.Split(pattern, ".")
|
||||
seg := strings.Split(ip, ".")
|
||||
for i, part := range parts {
|
||||
|
||||
// normalize pattern to 3 digits
|
||||
switch len(part) {
|
||||
case 1:
|
||||
if part == "*" {
|
||||
part = "***"
|
||||
} else {
|
||||
part = "00" + part
|
||||
}
|
||||
case 2:
|
||||
if strings.HasPrefix(part, "*") {
|
||||
part = "*" + part
|
||||
} else if strings.HasSuffix(part, "*") {
|
||||
part = part + "*"
|
||||
} else {
|
||||
part = "0" + part
|
||||
}
|
||||
}
|
||||
|
||||
// normalize ip to 3 digits
|
||||
switch len(seg[i]) {
|
||||
case 1:
|
||||
seg[i] = "00" + seg[i]
|
||||
case 2:
|
||||
seg[i] = "0" + seg[i]
|
||||
}
|
||||
|
||||
for j := range part {
|
||||
if string(part[j]) == "*" {
|
||||
continue
|
||||
}
|
||||
if part[j] != seg[i][j] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetIP gets the real ip (could still be tricked by proxy)
|
||||
func GetIP(r *http.Request) string {
|
||||
ip := r.Header.Get("X-Real-Ip")
|
||||
if ip == "" {
|
||||
ips := r.Header.Get("X-Forwarded-For")
|
||||
ipArr := strings.Split(ips, ",")
|
||||
ip = strings.Trim(ipArr[len(ipArr)-1], " ")
|
||||
}
|
||||
if ip == "" {
|
||||
var err error
|
||||
ip, _, err = net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
}
|
||||
return ip
|
||||
}
|
79
util.go
79
util.go
|
@ -1,80 +1,17 @@
|
|||
package guard
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// JSON shorthand for json response
|
||||
func JSON(w http.ResponseWriter, value interface{}) (int, error) {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return w.Write(data)
|
||||
}
|
||||
|
||||
// MatchIPGlob match ip to glob pattern, ex. * 192.168.* 192.* 192.168.51.*2*
|
||||
func MatchIPGlob(ip, pattern string) bool {
|
||||
parts := strings.Split(pattern, ".")
|
||||
seg := strings.Split(ip, ".")
|
||||
for i, part := range parts {
|
||||
|
||||
// normalize pattern to 3 digits
|
||||
switch len(part) {
|
||||
case 1:
|
||||
if part == "*" {
|
||||
part = "***"
|
||||
} else {
|
||||
part = "00" + part
|
||||
}
|
||||
case 2:
|
||||
if strings.HasPrefix(part, "*") {
|
||||
part = "*" + part
|
||||
} else if strings.HasSuffix(part, "*") {
|
||||
part = part + "*"
|
||||
} else {
|
||||
part = "0" + part
|
||||
}
|
||||
}
|
||||
|
||||
// normalize ip to 3 digits
|
||||
switch len(seg[i]) {
|
||||
case 1:
|
||||
seg[i] = "00" + seg[i]
|
||||
case 2:
|
||||
seg[i] = "0" + seg[i]
|
||||
}
|
||||
|
||||
for j := range part {
|
||||
if string(part[j]) == "*" {
|
||||
continue
|
||||
}
|
||||
if part[j] != seg[i][j] {
|
||||
func SkipStatic(r *http.Request) bool {
|
||||
switch filepath.Ext(r.URL.Path) {
|
||||
case ".js":
|
||||
fallthrough
|
||||
case ".css":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetIP gets the real ip (could still be tricked by proxy)
|
||||
func GetIP(r *http.Request) string {
|
||||
ip := r.Header.Get("X-Real-Ip")
|
||||
if ip == "" {
|
||||
ips := r.Header.Get("X-Forwarded-For")
|
||||
ipArr := strings.Split(ips, ",")
|
||||
ip = strings.Trim(ipArr[len(ipArr)-1], " ")
|
||||
}
|
||||
if ip == "" {
|
||||
var err error
|
||||
ip, _, err = net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue