breach/breacher/ssh.go

262 lines
5.7 KiB
Go

package breacher
import (
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"strconv"
	"strings"
	"sync"

	"github.com/spf13/cobra"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)
// sPasswd holds the --password/-p flag: when non-empty, log in with it.
var sPasswd string

// sKey holds the --keyfile/-i flag: path of a private key file to log in with.
var sKey string
// init binds the authentication flags of the tunnel command to sPasswd and sKey.
func init() {
	flags := sshCmd.Flags()
	flags.StringVarP(&sPasswd, "password", "p", "", "login using password")
	flags.StringVarP(&sKey, "keyfile", "i", "", "login keyfile (path/to/file)")
}
// sshCmd implements the "tunnel" subcommand. It listens on [from address]
// and forwards every accepted connection to [to address] through an ssh
// connection to [user@host:port].
//
// The authentication method is picked in this order: --password, then
// --keyfile, then the ssh agent reachable through SSH_AUTH_SOCK.
var sshCmd = &cobra.Command{
	Use:   "tunnel [from address] [to address] [user@host:port]",
	Short: "ssh tunneling to access remote services",
	Long: `ssh tunneling to access remote services
ex.
breacher tunnel :8080 host:80 user@example.com -p paswd
breacher tunnel :8080 :80 user@example.com -i ~/.ssh/id_rsa
breacher tunnel :8080 kumoly.io:443 user@example.com
`,
	Args: cobra.ExactArgs(3),
	Run: func(cmd *cobra.Command, args []string) {
		// An empty local host listens on all interfaces. An empty remote host
		// means "localhost", and it is resolved by the ssh server, because
		// the server is the side that dials the remote address.
		local, err := parseTunnelAddress(args[0], "0.0.0.0")
		if err != nil {
			log.Fatalln(err)
		}
		remote, err := parseTunnelAddress(args[1], "localhost")
		if err != nil {
			log.Fatalln(err)
		}
		server, err := parseSSHTarget(args[2])
		if err != nil {
			log.Fatalln(err)
		}

		var auth ssh.AuthMethod
		switch {
		case sPasswd != "":
			auth = ssh.Password(sPasswd)
		case sKey != "":
			auth = PrivateKeyFile(sKey)
		default:
			auth = SSHAgent()
		}
		// PrivateKeyFile and SSHAgent return nil on failure. A nil entry in
		// ssh.ClientConfig.Auth is not a usable method: the ssh client calls
		// methods on every entry, so fail here with a clear message instead.
		if auth == nil {
			log.Fatalln("no usable ssh authentication method: pass --password or --keyfile, or run an ssh agent")
		}

		st := NewSSHTunnel(local, remote, server, auth)
		// Start only returns nil after a clean Close; exit non-zero on errors only.
		if err := st.Start(); err != nil {
			log.Fatalln(err)
		}
	},
}

// parseTunnelAddress parses a "host:port" command-line address into an
// Endpoint. When the host part is empty, defaultHost is used instead.
func parseTunnelAddress(addr, defaultHost string) (*Endpoint, error) {
	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, err
	}
	if host == "" {
		host = defaultHost
	}
	return &Endpoint{Host: host, Port: port}, nil
}

// parseSSHTarget parses "user@host" or "user@host:port" into the ssh server
// Endpoint. The port defaults to 22 when it is absent or empty ("user@host:").
func parseSSHTarget(target string) (*Endpoint, error) {
	parts := strings.Split(target, "@")
	if len(parts) != 2 {
		return nil, fmt.Errorf("ssh host name not valid")
	}
	server := &Endpoint{Port: 22, User: parts[0]}
	if !strings.Contains(parts[1], ":") {
		server.Host = parts[1]
	} else {
		host, portStr, err := net.SplitHostPort(parts[1])
		if err != nil {
			return nil, err
		}
		server.Host = host
		if portStr != "" {
			if server.Port, err = strconv.Atoi(portStr); err != nil {
				return nil, err
			}
		}
	}
	// Check both forms. Without this, "user@" would silently dial ":22",
	// i.e. this machine.
	if server.Host == "" {
		return nil, fmt.Errorf("no ssh host")
	}
	return server, nil
}
// Endpoint is a network address. User is the login name and is only used
// for the ssh server endpoint.
type Endpoint struct {
	Host string
	Port int
	User string
}

// String returns the endpoint as a "host:port" address suitable for
// net.Dial, net.Listen and ssh.Dial.
//
// net.JoinHostPort brackets IPv6 literals ("[::1]:22"), which those functions
// require; plain "%s:%d" formatting produced the unusable "::1:22". A host
// that is already bracketed is unwrapped first so it is not double-bracketed.
// An empty Host yields ":port".
func (endpoint *Endpoint) String() string {
	host := endpoint.Host
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, strconv.Itoa(endpoint.Port))
}
// SSHTunnel forwards every connection accepted on Local to Remote. For each
// connection it dials Remote from inside an ssh connection to Server. Create
// one with NewSSHTunnel; a tunnel is single-use.
type SSHTunnel struct {
	Local    *Endpoint
	Server   *Endpoint
	Remote   *Endpoint
	Config   *ssh.ClientConfig
	Conns    []net.Conn    // open local and remote connections; closed when Start stops
	SvrConns []*ssh.Client // open ssh clients, one per forwarded connection; closed when Start stops
	isOpen   bool          // true while Start admits new connections
	close    chan interface{}

	// mu guards Conns, SvrConns, isOpen, close and closeRequested. These are
	// touched concurrently by Start, Close and the forward goroutines.
	mu             sync.Mutex
	closeRequested bool
}

// newConnectionWaiter accepts a single connection from listener and sends it
// on c. If Accept fails, it logs the error and closes c instead, so the
// receiver learns that no connection is coming rather than waiting forever.
// c should be buffered (capacity >= 1) so the send cannot block after the
// receiver has stopped listening.
func newConnectionWaiter(listener net.Listener, c chan net.Conn) {
	conn, err := listener.Accept()
	if err != nil {
		log.Println(err)
		close(c)
		return
	}
	c <- conn
}

// Start listens on Local and forwards each accepted connection in its own
// goroutine. It runs until Close is called or Accept fails. On the way out it
// closes every tracked connection and the listener.
//
// Start returns the listen error, the listener-close error, or a non-nil
// error if Accept failed. It returns nil after a normal Close. When
// Local.Port is 0, it is overwritten with the port actually bound.
func (tunnel *SSHTunnel) Start() error {
	listener, err := net.Listen("tcp", tunnel.Local.String())
	if err != nil {
		return err
	}
	tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port

	tunnel.mu.Lock()
	tunnel.isOpen = true
	if tunnel.close == nil {
		tunnel.close = make(chan interface{})
	}
	closeSignal := tunnel.close
	tunnel.mu.Unlock()

	var acceptErr error
	var waiting chan net.Conn
	for running := true; running; {
		// Buffered so the waiter can always deliver its connection and exit,
		// even after this loop has stopped receiving.
		waiting = make(chan net.Conn, 1)
		go newConnectionWaiter(listener, waiting)
		log.Println("listening for new connections...")
		select {
		case <-closeSignal:
			log.Println("close signal received, closing...")
			running = false
		case conn, ok := <-waiting:
			if !ok {
				// newConnectionWaiter already logged the Accept error. Stop
				// instead of waiting for a connection that will never come.
				acceptErr = fmt.Errorf("accepting connections on %s failed", tunnel.Local)
				running = false
			} else {
				tunnel.track(conn)
				log.Println("accepted connection")
				go tunnel.forward(conn)
			}
		}
	}

	// Stop admitting connections, then close everything opened so far. The
	// loops below work on copies because forward goroutines may release
	// entries concurrently.
	tunnel.mu.Lock()
	tunnel.isOpen = false
	conns := append([]net.Conn(nil), tunnel.Conns...)
	svrConns := append([]*ssh.Client(nil), tunnel.SvrConns...)
	tunnel.mu.Unlock()

	for i, conn := range conns {
		log.Printf("closing the netConn (%d of %d)\n", i+1, len(conns))
		if err := conn.Close(); err != nil {
			log.Println(err.Error())
		}
	}
	for i, conn := range svrConns {
		log.Printf("closing the serverConn (%d of %d)\n", i+1, len(svrConns))
		if err := conn.Close(); err != nil {
			log.Println(err.Error())
		}
	}
	if err := listener.Close(); err != nil {
		return err
	}
	// Closing the listener unblocks the last waiter's Accept, so this receive
	// cannot block forever. Close any connection that slipped through.
	if conn, ok := <-waiting; ok {
		conn.Close()
	}
	log.Println("tunnel closed")
	return acceptErr
}

// track records conn in Conns so that Start closes it on shutdown. It
// returns false, and records nothing, once Start has begun shutting down.
func (tunnel *SSHTunnel) track(conn net.Conn) bool {
	tunnel.mu.Lock()
	defer tunnel.mu.Unlock()
	if !tunnel.isOpen {
		return false
	}
	tunnel.Conns = append(tunnel.Conns, conn)
	return true
}

// trackClient is track for ssh clients (SvrConns).
func (tunnel *SSHTunnel) trackClient(client *ssh.Client) bool {
	tunnel.mu.Lock()
	defer tunnel.mu.Unlock()
	if !tunnel.isOpen {
		return false
	}
	tunnel.SvrConns = append(tunnel.SvrConns, client)
	return true
}

// release removes conn from Conns and closes it. The Close error is ignored
// on purpose, because Start's shutdown may already have closed conn.
func (tunnel *SSHTunnel) release(conn net.Conn) {
	tunnel.mu.Lock()
	for i, c := range tunnel.Conns {
		if c == conn {
			tunnel.Conns = append(tunnel.Conns[:i], tunnel.Conns[i+1:]...)
			break
		}
	}
	tunnel.mu.Unlock()
	_ = conn.Close()
}

// releaseClient is release for ssh clients (SvrConns).
func (tunnel *SSHTunnel) releaseClient(client *ssh.Client) {
	tunnel.mu.Lock()
	for i, c := range tunnel.SvrConns {
		if c == client {
			tunnel.SvrConns = append(tunnel.SvrConns[:i], tunnel.SvrConns[i+1:]...)
			break
		}
	}
	tunnel.mu.Unlock()
	_ = client.Close()
}

// forward connects localConn to Remote. It opens a fresh ssh client to
// Server, dials Remote through it, and copies bytes in both directions. When
// both directions finish, or either one fails, it closes all three
// connections. forward blocks until that happens; Start runs it in its own
// goroutine. If either dial fails, localConn is closed so the client is not
// left hanging.
func (tunnel *SSHTunnel) forward(localConn net.Conn) {
	serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
	if err != nil {
		log.Printf("server dial error: %s\n", err)
		tunnel.release(localConn)
		return
	}
	log.Printf("connected to %s (1 of 2)\n", tunnel.Server.String())
	if !tunnel.trackClient(serverConn) {
		// Start shut down while we were dialing and will not close this client.
		serverConn.Close()
		tunnel.release(localConn)
		return
	}
	remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
	if err != nil {
		log.Printf("remote dial error: %s\n", err)
		tunnel.releaseClient(serverConn)
		tunnel.release(localConn)
		return
	}
	if !tunnel.track(remoteConn) {
		remoteConn.Close()
		tunnel.releaseClient(serverConn)
		tunnel.release(localConn)
		return
	}
	log.Printf("connected to %s (2 of 2)\n", tunnel.Remote.String())

	var teardownOnce sync.Once
	teardown := func() {
		teardownOnce.Do(func() {
			tunnel.release(localConn)
			tunnel.release(remoteConn)
			tunnel.releaseClient(serverConn)
		})
	}
	copyConn := func(writer, reader net.Conn) {
		_, err := io.Copy(writer, reader)
		if err != nil {
			log.Printf("io.Copy error: %s\n", err)
			// A broken direction breaks the whole connection. Closing
			// everything also unblocks the opposite copy.
			teardown()
			return
		}
		// reader hit EOF: pass the half-close on so the peer sees the end of
		// input while the opposite direction keeps flowing. If the writer
		// cannot half-close, fall back to a full teardown.
		if cw, ok := writer.(interface{ CloseWrite() error }); !ok || cw.CloseWrite() != nil {
			teardown()
		}
	}
	done := make(chan struct{})
	go func() {
		copyConn(localConn, remoteConn)
		close(done)
	}()
	copyConn(remoteConn, localConn)
	<-done
	teardown()
}

// Close asks Start to stop accepting connections and shut the tunnel down.
// It never blocks. It is safe to call more than once, before Start, or after
// Start has returned. The request is sticky: a closed tunnel stays closed.
func (tunnel *SSHTunnel) Close() {
	tunnel.mu.Lock()
	defer tunnel.mu.Unlock()
	if tunnel.closeRequested {
		return
	}
	if tunnel.close == nil {
		tunnel.close = make(chan interface{})
	}
	close(tunnel.close)
	tunnel.closeRequested = true
}
// NewSSHTunnel creates a new single-use tunnel. The tunnel forwards
// connections accepted on from to to, via the ssh server at server, logging
// in as server.User with auth.
//
// Supplying 0 for from.Port listens on a random free port; Start writes the
// chosen port back into from.Port. A zero server.Port is replaced with 22,
// which modifies *server.
//
// SECURITY: the server's host key is not verified, so the ssh connection is
// open to man-in-the-middle attacks. A nil auth is left out of the config
// rather than stored, so the ssh handshake reports an authentication error
// instead of panicking.
func NewSSHTunnel(from, to, server *Endpoint, auth ssh.AuthMethod) *SSHTunnel {
	if server.Port == 0 {
		server.Port = 22
	}
	config := &ssh.ClientConfig{
		User: server.User,
		// Accepts any host key. TODO(review): verify against known_hosts
		// (golang.org/x/crypto/ssh/knownhosts) before using on untrusted networks.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	if auth != nil {
		config.Auth = []ssh.AuthMethod{auth}
	}
	return &SSHTunnel{
		Local:  from,
		Server: server,
		Remote: to,
		Config: config,
		close:  make(chan interface{}),
	}
}
// PrivateKeyFile returns an ssh.AuthMethod that authenticates with the
// private key stored in file. It returns nil when the file cannot be read or
// its contents cannot be parsed by ssh.ParsePrivateKey. The reason is logged
// so the failure is not silent; callers must still check for nil.
func PrivateKeyFile(file string) ssh.AuthMethod {
	buffer, err := ioutil.ReadFile(file)
	if err != nil {
		log.Printf("reading private key %q: %s\n", file, err)
		return nil
	}
	key, err := ssh.ParsePrivateKey(buffer)
	if err != nil {
		log.Printf("parsing private key %q: %s\n", file, err)
		return nil
	}
	return ssh.PublicKeys(key)
}
// SSHAgent returns an ssh.AuthMethod that uses the signers offered by the ssh
// agent listening on the unix socket named by SSH_AUTH_SOCK. It returns nil,
// and logs why, when SSH_AUTH_SOCK is unset or the socket cannot be dialed.
// The agent connection is intentionally left open, because the returned
// method asks the agent for signers each time it is used.
func SSHAgent() ssh.AuthMethod {
	socket := os.Getenv("SSH_AUTH_SOCK")
	if socket == "" {
		log.Println("SSH_AUTH_SOCK is not set; no ssh agent available")
		return nil
	}
	conn, err := net.Dial("unix", socket)
	if err != nil {
		log.Printf("connecting to ssh agent at %s: %s\n", socket, err)
		return nil
	}
	return ssh.PublicKeysCallback(agent.NewClient(conn).Signers)
}