diff --git a/guard.go b/guard.go index f3fa9d2..f92237a 100644 --- a/guard.go +++ b/guard.go @@ -11,6 +11,7 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "kumoly.io/lib/guard/netutil" ) type Guard struct { @@ -67,7 +68,7 @@ func (g *Guard) Guard(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() rw := &responseWriter{w, 0, ""} - ip := GetIP(r) + ip := netutil.GetIP(r) defer func() { var cl *zerolog.Event err := recover() @@ -76,7 +77,7 @@ func (g *Guard) Guard(next http.Handler) http.Handler { switch v := err.(type) { case Error: rw.WriteHeader(v.Code) - JSON(rw, v) + netutil.JSON(rw, v) cl.Err(v) case error: rw.WriteHeader(500) @@ -108,7 +109,7 @@ func (g *Guard) Guard(next http.Handler) http.Handler { if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) { panic(ErrorForbidden) } - if g.AllowIPGlob != "" && MatchIPGlob(ip, g.AllowIPGlob) { + if g.AllowIPGlob != "" && netutil.MatchIPGlob(ip, g.AllowIPGlob) { panic(ErrorForbidden) } if g.basicAuth != "" { diff --git a/netutil/util.go b/netutil/util.go new file mode 100644 index 0000000..2756010 --- /dev/null +++ b/netutil/util.go @@ -0,0 +1,80 @@ +package netutil + +import ( + "encoding/json" + "net" + "net/http" + "strings" +) + +// JSON shorthand for json response +func JSON(w http.ResponseWriter, value interface{}) (int, error) { + data, err := json.Marshal(value) + if err != nil { + panic(err) + } + w.Header().Set("Content-Type", "application/json") + return w.Write(data) +} + +// MatchIPGlob match ip to glob pattern, ex. * 192.168.* 192.* 192.168.51.*2* +func MatchIPGlob(ip, pattern string) bool { + parts := strings.Split(pattern, ".") + seg := strings.Split(ip, ".") + for i, part := range parts { + + // normalize pattern to 3 digits + switch len(part) { + case 1: + if part == "*" { + part = "***" + } else { + part = "00" + part + } + case 2: + if strings.HasPrefix(part, "*") { + part = "*" + part + } else if strings.HasSuffix(part, "*") { + part = part + "*" + } else { + part = "0" + part + } + } + + // normalize ip to 3 digits + switch len(seg[i]) { + case 1: + seg[i] = "00" + seg[i] + case 2: + seg[i] = "0" + seg[i] + } + + for j := range part { + if string(part[j]) == "*" { + continue + } + if part[j] != seg[i][j] { + return false + } + } + } + return true +} + +// GetIP gets the real ip (could still be tricked by proxy) +func GetIP(r *http.Request) string { + ip := r.Header.Get("X-Real-Ip") + if ip == "" { + ips := r.Header.Get("X-Forwarded-For") + ipArr := strings.Split(ips, ",") + ip = strings.Trim(ipArr[len(ipArr)-1], " ") + } + if ip == "" { + var err error + ip, _, err = net.SplitHostPort(r.RemoteAddr) + if err != nil { + ip = r.RemoteAddr + } + } + return ip +} diff --git a/util.go b/util.go index 3d718e4..10b6bb0 100644 --- a/util.go +++ b/util.go @@ -1,80 +1,17 @@ package guard import ( - "encoding/json" - "net" "net/http" - "strings" + "path/filepath" ) -// JSON shorthand for json response -func JSON(w http.ResponseWriter, value interface{}) (int, error) { - data, err := json.Marshal(value) - if err != nil { - panic(err) +func SkipStatic(r *http.Request) bool { + switch filepath.Ext(r.URL.Path) { + case ".js": + fallthrough + case ".css": + return true + default: + return false } - w.Header().Set("Content-Type", "application/json") - return w.Write(data) -} - -// MatchIPGlob match ip to glob pattern, ex. * 192.168.* 192.* 192.168.51.*2* -func MatchIPGlob(ip, pattern string) bool { - parts := strings.Split(pattern, ".") - seg := strings.Split(ip, ".") - for i, part := range parts { - - // normalize pattern to 3 digits - switch len(part) { - case 1: - if part == "*" { - part = "***" - } else { - part = "00" + part - } - case 2: - if strings.HasPrefix(part, "*") { - part = "*" + part - } else if strings.HasSuffix(part, "*") { - part = part + "*" - } else { - part = "0" + part - } - } - - // normalize ip to 3 digits - switch len(seg[i]) { - case 1: - seg[i] = "00" + seg[i] - case 2: - seg[i] = "0" + seg[i] - } - - for j := range part { - if string(part[j]) == "*" { - continue - } - if part[j] != seg[i][j] { - return false - } - } - } - return true -} - -// GetIP gets the real ip (could still be tricked by proxy) -func GetIP(r *http.Request) string { - ip := r.Header.Get("X-Real-Ip") - if ip == "" { - ips := r.Header.Get("X-Forwarded-For") - ipArr := strings.Split(ips, ",") - ip = strings.Trim(ipArr[len(ipArr)-1], " ") - } - if ip == "" { - var err error - ip, _, err = net.SplitHostPort(r.RemoteAddr) - if err != nil { - ip = r.RemoteAddr - } - } - return ip }