package guard import ( "bufio" "encoding/base64" "errors" "fmt" "net" "net/http" "time" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "kumoly.io/lib/guard/netutil" ) type Guard struct { AllowIPGlob string `json:"allow_ip_glob"` AllowIPNet *net.IPNet `json:"allow_ipnet"` basicAuth string Skip func(r *http.Request) bool } func New() *Guard { return &Guard{} } func (g *Guard) SetBasicAuth(user, pass string) { src := user + ":" + pass g.basicAuth = "Basic " + base64.URLEncoding.EncodeToString([]byte(src)) } type responseWriter struct { http.ResponseWriter StatueCode int err string } func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { h, ok := w.ResponseWriter.(http.Hijacker) if !ok { return nil, nil, errors.New("hijack not supported") } return h.Hijack() } func (w *responseWriter) WriteHeader(statusCode int) { if w.StatueCode != 0 { return } w.StatueCode = statusCode w.ResponseWriter.WriteHeader(statusCode) } func (w *responseWriter) Write(body []byte) (int, error) { if w.StatueCode >= 500 { w.err = string(body) } if w.StatueCode == 0 { w.WriteHeader(200) } return w.ResponseWriter.Write(body) } func (g *Guard) Guard(next http.Handler) http.Handler { l := log.With().Str("mod", "guard").Logger() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() rw := &responseWriter{w, 0, ""} ip := netutil.GetIP(r) defer func() { var cl *zerolog.Event err := recover() if err != nil { cl = l.Error() switch v := err.(type) { case Error: rw.WriteHeader(v.Code) netutil.JSON(rw, v) cl.Err(v) case error: rw.WriteHeader(500) rw.Write([]byte(v.Error())) cl.Err(v) default: rw.WriteHeader(500) rw.Write([]byte(fmt.Sprint(err))) cl.Str("error", fmt.Sprint(err)) } } else if rw.StatueCode >= 500 { cl = l.Error().Str("error", rw.err) } else { cl = l.Info() } if g.Skip != nil && g.Skip(r) { return } cl. Str("method", r.Method). Str("ip", ip). Int("status", rw.StatueCode). Dur("duration", time.Since(start)). Stringer("url", r.URL). Msg("") }() // guard if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) { panic(ErrorForbidden) } if g.AllowIPGlob != "" && netutil.MatchIPGlob(ip, g.AllowIPGlob) { panic(ErrorForbidden) } if g.basicAuth != "" { auth := r.Header.Get("Authorization") if auth != g.basicAuth { rw.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) rw.WriteHeader(http.StatusUnauthorized) rw.Write([]byte("Unauthorized")) } } next.ServeHTTP(rw, r) }) } // func