feat: add default skip static

master
Evan Chen 2021-11-19 20:09:50 +08:00
parent e27c6e7f45
commit cd95cf5454
3 changed files with 93 additions and 75 deletions

View File

@ -11,6 +11,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"kumoly.io/lib/guard/netutil"
) )
type Guard struct { type Guard struct {
@ -67,7 +68,7 @@ func (g *Guard) Guard(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
rw := &responseWriter{w, 0, ""} rw := &responseWriter{w, 0, ""}
ip := GetIP(r) ip := netutil.GetIP(r)
defer func() { defer func() {
var cl *zerolog.Event var cl *zerolog.Event
err := recover() err := recover()
@ -76,7 +77,7 @@ func (g *Guard) Guard(next http.Handler) http.Handler {
switch v := err.(type) { switch v := err.(type) {
case Error: case Error:
rw.WriteHeader(v.Code) rw.WriteHeader(v.Code)
JSON(rw, v) netutil.JSON(rw, v)
cl.Err(v) cl.Err(v)
case error: case error:
rw.WriteHeader(500) rw.WriteHeader(500)
@ -108,7 +109,7 @@ func (g *Guard) Guard(next http.Handler) http.Handler {
if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) { if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) {
panic(ErrorForbidden) panic(ErrorForbidden)
} }
if g.AllowIPGlob != "" && MatchIPGlob(ip, g.AllowIPGlob) { if g.AllowIPGlob != "" && netutil.MatchIPGlob(ip, g.AllowIPGlob) {
panic(ErrorForbidden) panic(ErrorForbidden)
} }
if g.basicAuth != "" { if g.basicAuth != "" {

80
netutil/util.go Normal file
View File

@ -0,0 +1,80 @@
package netutil
import (
"encoding/json"
"net"
"net/http"
"strings"
)
// JSON shorthand for json response
func JSON(w http.ResponseWriter, value interface{}) (int, error) {
data, err := json.Marshal(value)
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json")
return w.Write(data)
}
// MatchIPGlob match ip to glob pattern, ex. * 192.168.* 192.* 192.168.51.*2*
func MatchIPGlob(ip, pattern string) bool {
parts := strings.Split(pattern, ".")
seg := strings.Split(ip, ".")
for i, part := range parts {
// normalize pattern to 3 digits
switch len(part) {
case 1:
if part == "*" {
part = "***"
} else {
part = "00" + part
}
case 2:
if strings.HasPrefix(part, "*") {
part = "*" + part
} else if strings.HasSuffix(part, "*") {
part = part + "*"
} else {
part = "0" + part
}
}
// normalize ip to 3 digits
switch len(seg[i]) {
case 1:
seg[i] = "00" + seg[i]
case 2:
seg[i] = "0" + seg[i]
}
for j := range part {
if string(part[j]) == "*" {
continue
}
if part[j] != seg[i][j] {
return false
}
}
}
return true
}
// GetIP gets the real ip (could still be tricked by proxy)
func GetIP(r *http.Request) string {
ip := r.Header.Get("X-Real-Ip")
if ip == "" {
ips := r.Header.Get("X-Forwarded-For")
ipArr := strings.Split(ips, ",")
ip = strings.Trim(ipArr[len(ipArr)-1], " ")
}
if ip == "" {
var err error
ip, _, err = net.SplitHostPort(r.RemoteAddr)
if err != nil {
ip = r.RemoteAddr
}
}
return ip
}

79
util.go
View File

@ -1,80 +1,17 @@
package guard package guard
import ( import (
"encoding/json"
"net"
"net/http" "net/http"
"strings" "path/filepath"
) )
// JSON shorthand for json response func SkipStatic(r *http.Request) bool {
func JSON(w http.ResponseWriter, value interface{}) (int, error) { switch filepath.Ext(r.URL.Path) {
data, err := json.Marshal(value) case ".js":
if err != nil { fallthrough
panic(err) case ".css":
} return true
w.Header().Set("Content-Type", "application/json") default:
return w.Write(data)
}
// MatchIPGlob match ip to glob pattern, ex. * 192.168.* 192.* 192.168.51.*2*
func MatchIPGlob(ip, pattern string) bool {
parts := strings.Split(pattern, ".")
seg := strings.Split(ip, ".")
for i, part := range parts {
// normalize pattern to 3 digits
switch len(part) {
case 1:
if part == "*" {
part = "***"
} else {
part = "00" + part
}
case 2:
if strings.HasPrefix(part, "*") {
part = "*" + part
} else if strings.HasSuffix(part, "*") {
part = part + "*"
} else {
part = "0" + part
}
}
// normalize ip to 3 digits
switch len(seg[i]) {
case 1:
seg[i] = "00" + seg[i]
case 2:
seg[i] = "0" + seg[i]
}
for j := range part {
if string(part[j]) == "*" {
continue
}
if part[j] != seg[i][j] {
return false return false
} }
} }
}
return true
}
// GetIP gets the real ip (could still be tricked by proxy)
func GetIP(r *http.Request) string {
ip := r.Header.Get("X-Real-Ip")
if ip == "" {
ips := r.Header.Get("X-Forwarded-For")
ipArr := strings.Split(ips, ",")
ip = strings.Trim(ipArr[len(ipArr)-1], " ")
}
if ip == "" {
var err error
ip, _, err = net.SplitHostPort(r.RemoteAddr)
if err != nil {
ip = r.RemoteAddr
}
}
return ip
}