commit c5073d7a74c7c50a121d1e1e7219920e568798ac Author: Evan Chen Date: Fri Nov 19 14:35:53 2021 +0800 update diff --git a/error.go b/error.go new file mode 100644 index 0000000..5a9cde3 --- /dev/null +++ b/error.go @@ -0,0 +1,38 @@ +package guard + +import ( + "encoding/json" + "fmt" + "net/http" +) + +type Error struct { + Code int `json:"code"` + ID string `json:"id"` + Message string `json:"msg"` + Tmpl string `json:"-"` +} + +func (e Error) New(v ...interface{}) Error { + e.Message = fmt.Sprintf(e.Tmpl, v...) + return e +} + +func (e Error) Error() string { + return e.Message +} + +func (e Error) String() string { + return e.Message +} + +func (e Error) Json() []byte { + data, _ := json.Marshal(e) + return data +} + +var ErrorForbidden = Error{ + Code: http.StatusForbidden, + ID: "ErrorForbidden", + Message: "permission denied", +} diff --git a/example/basicauth/main.go b/example/basicauth/main.go new file mode 100644 index 0000000..c25e7dc --- /dev/null +++ b/example/basicauth/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "net/http" + + "kumoly.io/lib/guard" +) + +func main() { + mux := http.NewServeMux() + g := guard.New() + g.SetBasicAuth("evan", "evan") + mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("ok")) + }) + http.ListenAndServe("127.0.0.1:8000", g.Guard(mux)) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..614e1ad --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module kumoly.io/lib/guard + +go 1.17 + +require github.com/rs/zerolog v1.26.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ff1c250 --- /dev/null +++ b/go.sum @@ -0,0 +1,29 @@ +github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.26.0 h1:ORM4ibhEZeTeQlCojCK2kPz1ogAY4bGs4tD+SaAdGaE= +github.com/rs/zerolog v1.26.0/go.mod h1:yBiM87lvSqX8h0Ww4sdzNSkVYZ8dL2xjZJG1lAuGZEo= +github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/guard.go b/guard.go new file mode 100644 index 0000000..db6fb4a --- /dev/null +++ b/guard.go @@ -0,0 +1,113 @@ +package guard + +import ( + "encoding/base64" + "fmt" + "net" + "net/http" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +type Guard struct { + AllowIPGlob string `json:"allow_ip_glob"` + AllowIPNet *net.IPNet `json:"allow_ipnet"` + + basicAuth string +} + +func New() *Guard { + return &Guard{} +} + +func (g *Guard) SetBasicAuth(user, pass string) { + src := user + ":" + pass + g.basicAuth = "Basic " + base64.URLEncoding.EncodeToString([]byte(src)) +} + +type responseWriter struct { + http.ResponseWriter + StatueCode int + err string +} + +func (w *responseWriter) WriteHeader(statusCode int) { + if w.StatueCode != 0 { + return + } + w.StatueCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *responseWriter) Write(body []byte) (int, error) { + if w.StatueCode >= 500 { + w.err = string(body) + } + if w.StatueCode == 0 { + w.WriteHeader(200) + } + return w.ResponseWriter.Write(body) +} + +func (g *Guard) Guard(next http.Handler) http.Handler { + l := log.With().Str("mod", "guard").Logger() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + rw := &responseWriter{w, 0, ""} + ip := GetIP(r) + defer func() { + var cl *zerolog.Event + err := recover() + if err != nil { + cl = l.Error() + switch v := err.(type) { + case Error: + rw.WriteHeader(v.Code) + JSON(rw, v) + cl.Err(v) + case error: + rw.WriteHeader(500) + rw.Write([]byte(v.Error())) + cl.Err(v) + default: + rw.WriteHeader(500) + rw.Write([]byte(fmt.Sprint(err))) + cl.Str("error", fmt.Sprint(err)) + } + } else if rw.StatueCode >= 500 { + cl = l.Error().Str("error", rw.err) + } else { + cl = l.Info() + } + cl. + Str("method", r.Method). + Str("ip", ip). + Int("status", rw.StatueCode). + Dur("duration", time.Since(start)). + Stringer("url", r.URL). + Msg("") + }() + + // guard + if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) { + panic(ErrorForbidden) + } + if g.AllowIPGlob != "" && MatchIPGlob(ip, g.AllowIPGlob) { + panic(ErrorForbidden) + } + if g.basicAuth != "" { + auth := r.Header.Get("Authorization") + if auth != g.basicAuth { + rw.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) + rw.WriteHeader(http.StatusUnauthorized) + rw.Write([]byte("Unauthorized")) + } + } + + next.ServeHTTP(rw, r) + }) +} + +// func diff --git a/guard_test.go b/guard_test.go new file mode 100644 index 0000000..fe27835 --- /dev/null +++ b/guard_test.go @@ -0,0 +1,32 @@ +package guard + +import ( + "errors" + "net/http" + "testing" +) + +func TestGuard(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/panic", func(rw http.ResponseWriter, r *http.Request) { + panic(nil) + }) + mux.HandleFunc("/custerr", func(rw http.ResponseWriter, r *http.Request) { + panic(Error{ + Code: 404, + Message: "custerr", + ID: "test", + }) + }) + mux.HandleFunc("/err", func(rw http.ResponseWriter, r *http.Request) { + panic(errors.New("err")) + }) + mux.HandleFunc("/500", func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(500) + rw.Write([]byte("err")) + }) + mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("ok")) + }) + http.ListenAndServe(":8000", New().Guard(mux)) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..3d718e4 --- /dev/null +++ b/util.go @@ -0,0 +1,80 @@ +package guard + +import ( + "encoding/json" + "net" + "net/http" + "strings" +) + +// JSON shorthand for json response +func JSON(w http.ResponseWriter, value interface{}) (int, error) { + data, err := json.Marshal(value) + if err != nil { + panic(err) + } + w.Header().Set("Content-Type", "application/json") + return w.Write(data) +} + +// MatchIPGlob match ip to glob pattern, ex. * 192.168.* 192.* 192.168.51.*2* +func MatchIPGlob(ip, pattern string) bool { + parts := strings.Split(pattern, ".") + seg := strings.Split(ip, ".") + for i, part := range parts { + + // normalize pattern to 3 digits + switch len(part) { + case 1: + if part == "*" { + part = "***" + } else { + part = "00" + part + } + case 2: + if strings.HasPrefix(part, "*") { + part = "*" + part + } else if strings.HasSuffix(part, "*") { + part = part + "*" + } else { + part = "0" + part + } + } + + // normalize ip to 3 digits + switch len(seg[i]) { + case 1: + seg[i] = "00" + seg[i] + case 2: + seg[i] = "0" + seg[i] + } + + for j := range part { + if string(part[j]) == "*" { + continue + } + if part[j] != seg[i][j] { + return false + } + } + } + return true +} + +// GetIP gets the real ip (could still be tricked by proxy) +func GetIP(r *http.Request) string { + ip := r.Header.Get("X-Real-Ip") + if ip == "" { + ips := r.Header.Get("X-Forwarded-For") + ipArr := strings.Split(ips, ",") + ip = strings.Trim(ipArr[len(ipArr)-1], " ") + } + if ip == "" { + var err error + ip, _, err = net.SplitHostPort(r.RemoteAddr) + if err != nil { + ip = r.RemoteAddr + } + } + return ip +}