master
Evan Chen 2021-11-19 14:35:53 +08:00
commit c5073d7a74
7 changed files with 314 additions and 0 deletions

38
error.go Normal file
View File

@ -0,0 +1,38 @@
package guard
import (
"encoding/json"
"fmt"
"net/http"
)
type Error struct {
Code int `json:"code"`
ID string `json:"id"`
Message string `json:"msg"`
Tmpl string `json:"-"`
}
func (e Error) New(v ...interface{}) Error {
e.Message = fmt.Sprintf(e.Tmpl, v...)
return e
}
func (e Error) Error() string {
return e.Message
}
func (e Error) String() string {
return e.Message
}
func (e Error) Json() []byte {
data, _ := json.Marshal(e)
return data
}
var ErrorForbidden = Error{
Code: http.StatusForbidden,
ID: "ErrorForbidden",
Message: "permission denied",
}

17
example/basicauth/main.go Normal file
View File

@ -0,0 +1,17 @@
package main
import (
"net/http"
"kumoly.io/lib/guard"
)
func main() {
mux := http.NewServeMux()
g := guard.New()
g.SetBasicAuth("evan", "evan")
mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("ok"))
})
http.ListenAndServe("127.0.0.1:8000", g.Guard(mux))
}

5
go.mod Normal file
View File

@ -0,0 +1,5 @@
module kumoly.io/lib/guard
go 1.17
require github.com/rs/zerolog v1.26.0

29
go.sum Normal file
View File

@ -0,0 +1,29 @@
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.26.0 h1:ORM4ibhEZeTeQlCojCK2kPz1ogAY4bGs4tD+SaAdGaE=
github.com/rs/zerolog v1.26.0/go.mod h1:yBiM87lvSqX8h0Ww4sdzNSkVYZ8dL2xjZJG1lAuGZEo=
github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

113
guard.go Normal file
View File

@ -0,0 +1,113 @@
package guard
import (
"encoding/base64"
"fmt"
"net"
"net/http"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
type Guard struct {
AllowIPGlob string `json:"allow_ip_glob"`
AllowIPNet *net.IPNet `json:"allow_ipnet"`
basicAuth string
}
func New() *Guard {
return &Guard{}
}
func (g *Guard) SetBasicAuth(user, pass string) {
src := user + ":" + pass
g.basicAuth = "Basic " + base64.URLEncoding.EncodeToString([]byte(src))
}
type responseWriter struct {
http.ResponseWriter
StatueCode int
err string
}
func (w *responseWriter) WriteHeader(statusCode int) {
if w.StatueCode != 0 {
return
}
w.StatueCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *responseWriter) Write(body []byte) (int, error) {
if w.StatueCode >= 500 {
w.err = string(body)
}
if w.StatueCode == 0 {
w.WriteHeader(200)
}
return w.ResponseWriter.Write(body)
}
func (g *Guard) Guard(next http.Handler) http.Handler {
l := log.With().Str("mod", "guard").Logger()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
rw := &responseWriter{w, 0, ""}
ip := GetIP(r)
defer func() {
var cl *zerolog.Event
err := recover()
if err != nil {
cl = l.Error()
switch v := err.(type) {
case Error:
rw.WriteHeader(v.Code)
JSON(rw, v)
cl.Err(v)
case error:
rw.WriteHeader(500)
rw.Write([]byte(v.Error()))
cl.Err(v)
default:
rw.WriteHeader(500)
rw.Write([]byte(fmt.Sprint(err)))
cl.Str("error", fmt.Sprint(err))
}
} else if rw.StatueCode >= 500 {
cl = l.Error().Str("error", rw.err)
} else {
cl = l.Info()
}
cl.
Str("method", r.Method).
Str("ip", ip).
Int("status", rw.StatueCode).
Dur("duration", time.Since(start)).
Stringer("url", r.URL).
Msg("")
}()
// guard
if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) {
panic(ErrorForbidden)
}
if g.AllowIPGlob != "" && MatchIPGlob(ip, g.AllowIPGlob) {
panic(ErrorForbidden)
}
if g.basicAuth != "" {
auth := r.Header.Get("Authorization")
if auth != g.basicAuth {
rw.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
rw.WriteHeader(http.StatusUnauthorized)
rw.Write([]byte("Unauthorized"))
}
}
next.ServeHTTP(rw, r)
})
}
// func

32
guard_test.go Normal file
View File

@ -0,0 +1,32 @@
package guard
import (
"errors"
"net/http"
"testing"
)
func TestGuard(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/panic", func(rw http.ResponseWriter, r *http.Request) {
panic(nil)
})
mux.HandleFunc("/custerr", func(rw http.ResponseWriter, r *http.Request) {
panic(Error{
Code: 404,
Message: "custerr",
ID: "test",
})
})
mux.HandleFunc("/err", func(rw http.ResponseWriter, r *http.Request) {
panic(errors.New("err"))
})
mux.HandleFunc("/500", func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(500)
rw.Write([]byte("err"))
})
mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("ok"))
})
http.ListenAndServe(":8000", New().Guard(mux))
}

80
util.go Normal file
View File

@ -0,0 +1,80 @@
package guard
import (
"encoding/json"
"net"
"net/http"
"strings"
)
// JSON shorthand for json response
func JSON(w http.ResponseWriter, value interface{}) (int, error) {
data, err := json.Marshal(value)
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json")
return w.Write(data)
}
// MatchIPGlob match ip to glob pattern, ex. * 192.168.* 192.* 192.168.51.*2*
func MatchIPGlob(ip, pattern string) bool {
parts := strings.Split(pattern, ".")
seg := strings.Split(ip, ".")
for i, part := range parts {
// normalize pattern to 3 digits
switch len(part) {
case 1:
if part == "*" {
part = "***"
} else {
part = "00" + part
}
case 2:
if strings.HasPrefix(part, "*") {
part = "*" + part
} else if strings.HasSuffix(part, "*") {
part = part + "*"
} else {
part = "0" + part
}
}
// normalize ip to 3 digits
switch len(seg[i]) {
case 1:
seg[i] = "00" + seg[i]
case 2:
seg[i] = "0" + seg[i]
}
for j := range part {
if string(part[j]) == "*" {
continue
}
if part[j] != seg[i][j] {
return false
}
}
}
return true
}
// GetIP gets the real ip (could still be tricked by proxy)
func GetIP(r *http.Request) string {
ip := r.Header.Get("X-Real-Ip")
if ip == "" {
ips := r.Header.Get("X-Forwarded-For")
ipArr := strings.Split(ips, ",")
ip = strings.Trim(ipArr[len(ipArr)-1], " ")
}
if ip == "" {
var err error
ip, _, err = net.SplitHostPort(r.RemoteAddr)
if err != nil {
ip = r.RemoteAddr
}
}
return ip
}