update
commit
c5073d7a74
|
@ -0,0 +1,38 @@
|
|||
package guard
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
ID string `json:"id"`
|
||||
Message string `json:"msg"`
|
||||
Tmpl string `json:"-"`
|
||||
}
|
||||
|
||||
func (e Error) New(v ...interface{}) Error {
|
||||
e.Message = fmt.Sprintf(e.Tmpl, v...)
|
||||
return e
|
||||
}
|
||||
|
||||
func (e Error) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func (e Error) String() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func (e Error) Json() []byte {
|
||||
data, _ := json.Marshal(e)
|
||||
return data
|
||||
}
|
||||
|
||||
var ErrorForbidden = Error{
|
||||
Code: http.StatusForbidden,
|
||||
ID: "ErrorForbidden",
|
||||
Message: "permission denied",
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"kumoly.io/lib/guard"
|
||||
)
|
||||
|
||||
func main() {
|
||||
mux := http.NewServeMux()
|
||||
g := guard.New()
|
||||
g.SetBasicAuth("evan", "evan")
|
||||
mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("ok"))
|
||||
})
|
||||
http.ListenAndServe("127.0.0.1:8000", g.Guard(mux))
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
module kumoly.io/lib/guard
|
||||
|
||||
go 1.17
|
||||
|
||||
require github.com/rs/zerolog v1.26.0
|
|
@ -0,0 +1,29 @@
|
|||
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.26.0 h1:ORM4ibhEZeTeQlCojCK2kPz1ogAY4bGs4tD+SaAdGaE=
|
||||
github.com/rs/zerolog v1.26.0/go.mod h1:yBiM87lvSqX8h0Ww4sdzNSkVYZ8dL2xjZJG1lAuGZEo=
|
||||
github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
|
@ -0,0 +1,113 @@
|
|||
package guard
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Guard struct {
|
||||
AllowIPGlob string `json:"allow_ip_glob"`
|
||||
AllowIPNet *net.IPNet `json:"allow_ipnet"`
|
||||
|
||||
basicAuth string
|
||||
}
|
||||
|
||||
func New() *Guard {
|
||||
return &Guard{}
|
||||
}
|
||||
|
||||
func (g *Guard) SetBasicAuth(user, pass string) {
|
||||
src := user + ":" + pass
|
||||
g.basicAuth = "Basic " + base64.URLEncoding.EncodeToString([]byte(src))
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
StatueCode int
|
||||
err string
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(statusCode int) {
|
||||
if w.StatueCode != 0 {
|
||||
return
|
||||
}
|
||||
w.StatueCode = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *responseWriter) Write(body []byte) (int, error) {
|
||||
if w.StatueCode >= 500 {
|
||||
w.err = string(body)
|
||||
}
|
||||
if w.StatueCode == 0 {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
return w.ResponseWriter.Write(body)
|
||||
}
|
||||
|
||||
func (g *Guard) Guard(next http.Handler) http.Handler {
|
||||
l := log.With().Str("mod", "guard").Logger()
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
rw := &responseWriter{w, 0, ""}
|
||||
ip := GetIP(r)
|
||||
defer func() {
|
||||
var cl *zerolog.Event
|
||||
err := recover()
|
||||
if err != nil {
|
||||
cl = l.Error()
|
||||
switch v := err.(type) {
|
||||
case Error:
|
||||
rw.WriteHeader(v.Code)
|
||||
JSON(rw, v)
|
||||
cl.Err(v)
|
||||
case error:
|
||||
rw.WriteHeader(500)
|
||||
rw.Write([]byte(v.Error()))
|
||||
cl.Err(v)
|
||||
default:
|
||||
rw.WriteHeader(500)
|
||||
rw.Write([]byte(fmt.Sprint(err)))
|
||||
cl.Str("error", fmt.Sprint(err))
|
||||
}
|
||||
} else if rw.StatueCode >= 500 {
|
||||
cl = l.Error().Str("error", rw.err)
|
||||
} else {
|
||||
cl = l.Info()
|
||||
}
|
||||
cl.
|
||||
Str("method", r.Method).
|
||||
Str("ip", ip).
|
||||
Int("status", rw.StatueCode).
|
||||
Dur("duration", time.Since(start)).
|
||||
Stringer("url", r.URL).
|
||||
Msg("")
|
||||
}()
|
||||
|
||||
// guard
|
||||
if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) {
|
||||
panic(ErrorForbidden)
|
||||
}
|
||||
if g.AllowIPGlob != "" && MatchIPGlob(ip, g.AllowIPGlob) {
|
||||
panic(ErrorForbidden)
|
||||
}
|
||||
if g.basicAuth != "" {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth != g.basicAuth {
|
||||
rw.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
rw.Write([]byte("Unauthorized"))
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
|
||||
// func
|
|
@ -0,0 +1,32 @@
|
|||
package guard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGuard(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/panic", func(rw http.ResponseWriter, r *http.Request) {
|
||||
panic(nil)
|
||||
})
|
||||
mux.HandleFunc("/custerr", func(rw http.ResponseWriter, r *http.Request) {
|
||||
panic(Error{
|
||||
Code: 404,
|
||||
Message: "custerr",
|
||||
ID: "test",
|
||||
})
|
||||
})
|
||||
mux.HandleFunc("/err", func(rw http.ResponseWriter, r *http.Request) {
|
||||
panic(errors.New("err"))
|
||||
})
|
||||
mux.HandleFunc("/500", func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(500)
|
||||
rw.Write([]byte("err"))
|
||||
})
|
||||
mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("ok"))
|
||||
})
|
||||
http.ListenAndServe(":8000", New().Guard(mux))
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
package guard
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// JSON shorthand for json response
|
||||
func JSON(w http.ResponseWriter, value interface{}) (int, error) {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return w.Write(data)
|
||||
}
|
||||
|
||||
// MatchIPGlob match ip to glob pattern, ex. * 192.168.* 192.* 192.168.51.*2*
|
||||
func MatchIPGlob(ip, pattern string) bool {
|
||||
parts := strings.Split(pattern, ".")
|
||||
seg := strings.Split(ip, ".")
|
||||
for i, part := range parts {
|
||||
|
||||
// normalize pattern to 3 digits
|
||||
switch len(part) {
|
||||
case 1:
|
||||
if part == "*" {
|
||||
part = "***"
|
||||
} else {
|
||||
part = "00" + part
|
||||
}
|
||||
case 2:
|
||||
if strings.HasPrefix(part, "*") {
|
||||
part = "*" + part
|
||||
} else if strings.HasSuffix(part, "*") {
|
||||
part = part + "*"
|
||||
} else {
|
||||
part = "0" + part
|
||||
}
|
||||
}
|
||||
|
||||
// normalize ip to 3 digits
|
||||
switch len(seg[i]) {
|
||||
case 1:
|
||||
seg[i] = "00" + seg[i]
|
||||
case 2:
|
||||
seg[i] = "0" + seg[i]
|
||||
}
|
||||
|
||||
for j := range part {
|
||||
if string(part[j]) == "*" {
|
||||
continue
|
||||
}
|
||||
if part[j] != seg[i][j] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetIP gets the real ip (could still be tricked by proxy)
|
||||
func GetIP(r *http.Request) string {
|
||||
ip := r.Header.Get("X-Real-Ip")
|
||||
if ip == "" {
|
||||
ips := r.Header.Get("X-Forwarded-For")
|
||||
ipArr := strings.Split(ips, ",")
|
||||
ip = strings.Trim(ipArr[len(ipArr)-1], " ")
|
||||
}
|
||||
if ip == "" {
|
||||
var err error
|
||||
ip, _, err = net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
}
|
||||
return ip
|
||||
}
|
Loading…
Reference in New Issue