guard/guard.go

129 lines
2.7 KiB
Go

package guard
import (
"bufio"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"kumoly.io/lib/guard/netutil"
)
// Guard holds the configuration for the Guard middleware: optional client-IP
// restrictions (a CIDR network and/or an IP glob), optional HTTP Basic
// credentials, and a hook to suppress access logging for some requests.
// The zero value (see New) applies no restrictions.
type Guard struct {
// AllowIPGlob, when non-empty, is an IP glob pattern compared against the
// client IP with netutil.MatchIPGlob inside (*Guard).Guard.
AllowIPGlob string `json:"allow_ip_glob"`
// AllowIPNet, when non-nil, is the network the client IP must fall within;
// otherwise the middleware panics with ErrorForbidden, which its recovery
// turns into the error response.
AllowIPNet *net.IPNet `json:"allow_ipnet"`
// basicAuth is the exact Authorization header value expected from clients,
// precomputed by SetBasicAuth. Empty disables the basic-auth check.
basicAuth string
// Skip, when non-nil and returning true for a request, suppresses the
// access-log line for that request only; the IP/auth checks still run.
Skip func(r *http.Request) bool
}
// New returns a pointer to a fresh, zero-valued Guard: no IP restrictions,
// no basic auth and no Skip hook, so every request is allowed and logged
// until the caller configures it.
func New() *Guard {
	var g Guard
	return &g
}
// SetBasicAuth enables HTTP Basic authentication on g. Afterwards a request
// is accepted only if its Authorization header equals
// "Basic " + base64(user + ":" + pass).
//
// The credentials use the standard base64 alphabet, as RFC 7617 requires and
// as browsers and HTTP clients send them. The URL-safe alphabet used
// previously produced '-'/'_' where clients send '+'/'/', so some valid
// credentials could never match. RFC 7617 forbids ':' in the user-id; this
// function does not check for it.
func (g *Guard) SetBasicAuth(user, pass string) {
	credentials := user + ":" + pass
	g.basicAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
}
// responseWriter wraps an http.ResponseWriter so the Guard middleware can see
// which status code was sent and, for 5xx responses, what body was written,
// and include them in its log line.
type responseWriter struct {
http.ResponseWriter
// StatueCode is the first status code passed to WriteHeader (0 until one is
// sent). NOTE(review): "Statue" is a typo for "Status"; kept because the
// middleware refers to the field by this name.
StatueCode int
// err holds the most recent body chunk written while StatueCode >= 500
// (each such Write replaces it).
err string
}
// Hijack lets handlers behind Guard take over the underlying connection by
// delegating to the wrapped writer's http.Hijacker. It returns an error when
// the wrapped writer does not support hijacking.
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	h, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, errors.New("hijack not supported")
	}
	return h.Hijack()
}

// Flush forwards to the wrapped writer's http.Flusher. Without this method the
// wrapper hides the underlying Flusher, and streaming handlers behind Guard
// could not flush.
//
// If no status has been recorded yet, it records and sends 200 first, so the
// logged status matches what the flush commits to the client. It is a no-op
// when the wrapped writer cannot flush.
func (w *responseWriter) Flush() {
	f, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	if w.StatueCode == 0 {
		w.WriteHeader(http.StatusOK)
	}
	f.Flush()
}

// Unwrap returns the wrapped writer, so http.ResponseController can reach
// optional interfaces this wrapper does not forward itself.
func (w *responseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// WriteHeader passes only the first status code through to the wrapped writer
// and remembers it in StatueCode. Any later call is ignored, so the recorded
// status always matches the one actually sent.
func (w *responseWriter) WriteHeader(statusCode int) {
	if w.StatueCode == 0 {
		w.StatueCode = statusCode
		w.ResponseWriter.WriteHeader(statusCode)
	}
}
// Write sends body to the wrapped writer.
//
// If no status has been recorded yet, it records and sends 200 first,
// mirroring net/http's implicit status. While the recorded status is 5xx,
// the chunk is also stored in w.err, replacing any earlier chunk, so the
// middleware can log it.
func (w *responseWriter) Write(body []byte) (int, error) {
	switch code := w.StatueCode; {
	case code == 0:
		w.WriteHeader(200)
	case code >= 500:
		w.err = string(body)
	}
	return w.ResponseWriter.Write(body)
}
// Guard wraps next with g's access checks, panic recovery and access logging.
//
// For every request it:
//   - rejects the client by panicking with ErrorForbidden when AllowIPNet is
//     set and does not contain the client IP, or when AllowIPGlob is set and
//     the client IP does not match it;
//   - when basic auth is configured (SetBasicAuth) and the Authorization
//     header differs, replies 401 with a WWW-Authenticate challenge and does
//     NOT call next;
//   - recovers panics from the checks and from next and turns them into a
//     response: Error values are sent as JSON with their Code, other values
//     as a plain-text 500;
//   - logs method, client IP, status, duration and URL via zerolog, unless
//     g.Skip reports true for the request.
//
// A panic with http.ErrAbortHandler is re-raised untouched, so net/http can
// abort the response silently as that sentinel is meant to do.
func (g *Guard) Guard(next http.Handler) http.Handler {
	l := log.With().Str("mod", "guard").Logger()
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		rw := &responseWriter{w, 0, ""}
		ip := netutil.GetIP(r)
		defer func() {
			err := recover()
			if err == http.ErrAbortHandler {
				// net/http's "abort silently" sentinel must reach the server.
				panic(err)
			}
			var cl *zerolog.Event
			if err != nil {
				cl = l.Error()
				switch v := err.(type) {
				case Error:
					rw.WriteHeader(v.Code)
					netutil.JSON(rw, v)
					cl.Err(v)
				case error:
					rw.WriteHeader(500)
					rw.Write([]byte(v.Error()))
					cl.Err(v)
				default:
					rw.WriteHeader(500)
					rw.Write([]byte(fmt.Sprint(err)))
					cl.Str("error", fmt.Sprint(err))
				}
			} else if rw.StatueCode >= 500 {
				cl = l.Error().Str("error", rw.err)
			} else {
				cl = l.Info()
			}
			// Skip only suppresses logging; recovery above has already run.
			if g.Skip != nil && g.Skip(r) {
				return
			}
			cl.
				Str("method", r.Method).
				Str("ip", ip).
				Int("status", rw.StatueCode).
				Dur("duration", time.Since(start)).
				Stringer("url", r.URL).
				Msg("")
		}()

		// An unparsable IP yields nil, which no network contains: fail closed.
		if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) {
			panic(ErrorForbidden)
		}
		// AllowIPGlob is an allow-list, like AllowIPNet: reject IPs that do
		// NOT match it. The previous code rejected the matching ones instead.
		if g.AllowIPGlob != "" && !netutil.MatchIPGlob(ip, g.AllowIPGlob) {
			panic(ErrorForbidden)
		}
		if g.basicAuth != "" && r.Header.Get("Authorization") != g.basicAuth {
			rw.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
			rw.WriteHeader(http.StatusUnauthorized)
			rw.Write([]byte("Unauthorized"))
			// Stop here: falling through would run next and append the
			// protected response to the 401.
			return
		}
		next.ServeHTTP(rw, r)
	})
}
// func