124 lines
2.5 KiB
Go
124 lines
2.5 KiB
Go
package guard
|
|
|
|
import (
	"bufio"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"net"
	"net/http"
	"time"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)
|
|
|
|
// Guard holds the access rules applied by the Guard middleware.
//
// The zero value enforces nothing: with AllowIPNet nil, AllowIPGlob empty and
// no basic-auth credentials set, every request is passed straight through.
type Guard struct {
	// AllowIPGlob, when non-empty, is a glob pattern the client IP is checked
	// against.
	AllowIPGlob string `json:"allow_ip_glob"`
	// AllowIPNet, when non-nil, is the network the client IP must belong to.
	AllowIPNet *net.IPNet `json:"allow_ipnet"`

	// basicAuth is the complete expected Authorization header value, as built
	// by SetBasicAuth; the empty string disables the basic-auth check.
	basicAuth string
}
|
|
|
|
func New() *Guard {
|
|
return &Guard{}
|
|
}
|
|
|
|
func (g *Guard) SetBasicAuth(user, pass string) {
|
|
src := user + ":" + pass
|
|
g.basicAuth = "Basic " + base64.URLEncoding.EncodeToString([]byte(src))
|
|
}
|
|
|
|
// responseWriter wraps an http.ResponseWriter so the Guard middleware can
// observe the status code and error body produced by the wrapped handler.
//
// Keep the field order: the middleware constructs it with a positional
// composite literal.
type responseWriter struct {
	http.ResponseWriter

	// StatueCode is the first status code written; 0 until then. The name is
	// presumably a misspelling of "StatusCode". It is left as is because the
	// field is exported and may be referenced elsewhere.
	StatueCode int

	// err holds the most recent body slice passed to Write while StatueCode
	// is >= 500. It is used for the error log entry.
	err string
}
|
|
|
|
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
h, ok := w.ResponseWriter.(http.Hijacker)
|
|
if !ok {
|
|
return nil, nil, errors.New("hijack not supported")
|
|
}
|
|
return h.Hijack()
|
|
}
|
|
|
|
func (w *responseWriter) WriteHeader(statusCode int) {
|
|
if w.StatueCode != 0 {
|
|
return
|
|
}
|
|
w.StatueCode = statusCode
|
|
w.ResponseWriter.WriteHeader(statusCode)
|
|
}
|
|
|
|
func (w *responseWriter) Write(body []byte) (int, error) {
|
|
if w.StatueCode >= 500 {
|
|
w.err = string(body)
|
|
}
|
|
if w.StatueCode == 0 {
|
|
w.WriteHeader(200)
|
|
}
|
|
return w.ResponseWriter.Write(body)
|
|
}
|
|
|
|
func (g *Guard) Guard(next http.Handler) http.Handler {
|
|
l := log.With().Str("mod", "guard").Logger()
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
rw := &responseWriter{w, 0, ""}
|
|
ip := GetIP(r)
|
|
defer func() {
|
|
var cl *zerolog.Event
|
|
err := recover()
|
|
if err != nil {
|
|
cl = l.Error()
|
|
switch v := err.(type) {
|
|
case Error:
|
|
rw.WriteHeader(v.Code)
|
|
JSON(rw, v)
|
|
cl.Err(v)
|
|
case error:
|
|
rw.WriteHeader(500)
|
|
rw.Write([]byte(v.Error()))
|
|
cl.Err(v)
|
|
default:
|
|
rw.WriteHeader(500)
|
|
rw.Write([]byte(fmt.Sprint(err)))
|
|
cl.Str("error", fmt.Sprint(err))
|
|
}
|
|
} else if rw.StatueCode >= 500 {
|
|
cl = l.Error().Str("error", rw.err)
|
|
} else {
|
|
cl = l.Info()
|
|
}
|
|
cl.
|
|
Str("method", r.Method).
|
|
Str("ip", ip).
|
|
Int("status", rw.StatueCode).
|
|
Dur("duration", time.Since(start)).
|
|
Stringer("url", r.URL).
|
|
Msg("")
|
|
}()
|
|
|
|
// guard
|
|
if g.AllowIPNet != nil && !g.AllowIPNet.Contains(net.ParseIP(ip)) {
|
|
panic(ErrorForbidden)
|
|
}
|
|
if g.AllowIPGlob != "" && MatchIPGlob(ip, g.AllowIPGlob) {
|
|
panic(ErrorForbidden)
|
|
}
|
|
if g.basicAuth != "" {
|
|
auth := r.Header.Get("Authorization")
|
|
if auth != g.basicAuth {
|
|
rw.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
|
rw.WriteHeader(http.StatusUnauthorized)
|
|
rw.Write([]byte("Unauthorized"))
|
|
}
|
|
}
|
|
|
|
next.ServeHTTP(rw, r)
|
|
})
|
|
}
|
|
|
|
// func
|