262 lines
5.7 KiB
Go
262 lines
5.7 KiB
Go
package breacher
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/spf13/cobra"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
)
|
|
|
|
// sPasswd receives the value of the --password/-p flag of the tunnel command.
var sPasswd string

// sKey receives the value of the --keyfile/-i flag of the tunnel command,
// i.e. the path of the private key file used to log in.
var sKey string
|
|
|
|
func init() {
|
|
sshCmd.Flags().StringVarP(&sPasswd, "password", "p", "", "login using password")
|
|
sshCmd.Flags().StringVarP(&sKey, "keyfile", "i", "", "login keyfile (path/to/file)")
|
|
}
|
|
|
|
var sshCmd = &cobra.Command{
|
|
Use: "tunnel [from address] [to address] [user@host:port]",
|
|
Short: "ssh tunneling to access remote services",
|
|
Long: `ssh tunneling to access remote services
|
|
ex.
|
|
breacher tunnel :8080 host:80 user@example.com -p paswd
|
|
breacher tunnel :8080 :80 user@example.com -i ~/.ssh/id_rsa
|
|
breacher tunnel :8080 kumoly.io:443 user@example.com
|
|
`,
|
|
Args: cobra.ExactArgs(3),
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
localHost, localPortStr, err := net.SplitHostPort(args[0])
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
remoteHost, remotePortStr, err := net.SplitHostPort(args[1])
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
localPort, err := strconv.Atoi(localPortStr)
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
remotePort, err := strconv.Atoi(remotePortStr)
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
if localHost == "" {
|
|
localHost = "0.0.0.0"
|
|
}
|
|
if remoteHost == "" {
|
|
remoteHost = "localhost"
|
|
}
|
|
split := strings.Split(args[2], "@")
|
|
if len(split) != 2 {
|
|
log.Fatalln("ssh host name not valid")
|
|
}
|
|
usr := split[0]
|
|
sshHost := "localhost"
|
|
sshPort := 22
|
|
if !strings.Contains(split[1], ":") {
|
|
sshHost = split[1]
|
|
} else {
|
|
sshPortStr := ""
|
|
sshHost, sshPortStr, err = net.SplitHostPort(split[1])
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
if sshHost == "" {
|
|
log.Fatalln("no ssh host")
|
|
}
|
|
if sshPortStr != "" {
|
|
sshPort, err = strconv.Atoi(sshPortStr)
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
}
|
|
}
|
|
var auth ssh.AuthMethod
|
|
if sPasswd != "" {
|
|
auth = ssh.Password(sPasswd)
|
|
} else if sKey != "" {
|
|
auth = PrivateKeyFile(sKey)
|
|
} else {
|
|
auth = SSHAgent()
|
|
}
|
|
st := NewSSHTunnel(
|
|
&Endpoint{localHost, localPort, ""},
|
|
&Endpoint{remoteHost, remotePort, ""},
|
|
&Endpoint{sshHost, sshPort, usr},
|
|
auth,
|
|
)
|
|
|
|
log.Fatalln(st.Start())
|
|
},
|
|
}
|
|
|
|
// Endpoint describes one network address used by a tunnel.
type Endpoint struct {
	// Host is a host name or IP literal; it may be empty.
	Host string
	// Port is the TCP port number.
	Port int
	// User is the login name; it is only read for the SSH server endpoint.
	User string
}

// String returns the endpoint as a "host:port" address. An IPv6 literal host
// is wrapped in brackets ("[::1]:22") so the result can be passed to
// net.Listen and the Dial functions.
func (endpoint *Endpoint) String() string {
	return net.JoinHostPort(endpoint.Host, strconv.Itoa(endpoint.Port))
}
|
|
|
|
type SSHTunnel struct {
|
|
Local *Endpoint
|
|
Server *Endpoint
|
|
Remote *Endpoint
|
|
Config *ssh.ClientConfig
|
|
Conns []net.Conn
|
|
SvrConns []*ssh.Client
|
|
isOpen bool
|
|
close chan interface{}
|
|
}
|
|
|
|
// newConnectionWaiter accepts a single connection from listener and sends it
// on c. When Accept fails, the error is printed to standard output and
// nothing is sent.
func newConnectionWaiter(listener net.Listener, c chan net.Conn) {
	accepted, acceptErr := listener.Accept()
	if acceptErr == nil {
		c <- accepted
		return
	}
	fmt.Println(acceptErr)
}
|
|
|
|
func (tunnel *SSHTunnel) Start() error {
|
|
listener, err := net.Listen("tcp", tunnel.Local.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tunnel.isOpen = true
|
|
tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port
|
|
|
|
for {
|
|
if !tunnel.isOpen {
|
|
break
|
|
}
|
|
|
|
c := make(chan net.Conn)
|
|
go newConnectionWaiter(listener, c)
|
|
log.Println("listening for new connections...")
|
|
|
|
select {
|
|
case <-tunnel.close:
|
|
log.Println("close signal received, closing...")
|
|
tunnel.isOpen = false
|
|
case conn := <-c:
|
|
tunnel.Conns = append(tunnel.Conns, conn)
|
|
log.Println("accepted connection")
|
|
go tunnel.forward(conn)
|
|
}
|
|
}
|
|
var total int
|
|
total = len(tunnel.Conns)
|
|
for i, conn := range tunnel.Conns {
|
|
log.Printf("closing the netConn (%d of %d)\n", i+1, total)
|
|
err := conn.Close()
|
|
if err != nil {
|
|
log.Println(err.Error())
|
|
}
|
|
}
|
|
total = len(tunnel.SvrConns)
|
|
for i, conn := range tunnel.SvrConns {
|
|
log.Printf("closing the serverConn (%d of %d)\n", i+1, total)
|
|
err := conn.Close()
|
|
if err != nil {
|
|
log.Println(err.Error())
|
|
}
|
|
}
|
|
err = listener.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Println("tunnel closed")
|
|
return nil
|
|
}
|
|
|
|
func (tunnel *SSHTunnel) forward(localConn net.Conn) {
|
|
serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
|
|
if err != nil {
|
|
log.Printf("server dial error: %s\n", err)
|
|
return
|
|
}
|
|
log.Printf("connected to %s (1 of 2)\n", tunnel.Server.String())
|
|
tunnel.SvrConns = append(tunnel.SvrConns, serverConn)
|
|
|
|
remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
|
|
if err != nil {
|
|
log.Printf("remote dial error: %s\n", err)
|
|
return
|
|
}
|
|
tunnel.Conns = append(tunnel.Conns, remoteConn)
|
|
log.Printf("connected to %s (2 of 2)\n", tunnel.Remote.String())
|
|
copyConn := func(writer, reader net.Conn) {
|
|
_, err := io.Copy(writer, reader)
|
|
if err != nil {
|
|
log.Printf("io.Copy error: %s\n", err)
|
|
}
|
|
}
|
|
go copyConn(localConn, remoteConn)
|
|
go copyConn(remoteConn, localConn)
|
|
|
|
}
|
|
|
|
func (tunnel *SSHTunnel) Close() {
|
|
tunnel.close <- struct{}{}
|
|
}
|
|
|
|
// NewSSHTunnel creates a new single-use tunnel. Supplying "0" for localport will use a random port.
|
|
func NewSSHTunnel(from, to, server *Endpoint, auth ssh.AuthMethod) *SSHTunnel {
|
|
if server.Port == 0 {
|
|
server.Port = 22
|
|
}
|
|
|
|
sshTunnel := &SSHTunnel{
|
|
Config: &ssh.ClientConfig{
|
|
User: server.User,
|
|
Auth: []ssh.AuthMethod{auth},
|
|
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
// Always accept key.
|
|
return nil
|
|
},
|
|
},
|
|
Local: from,
|
|
Server: server,
|
|
Remote: to,
|
|
close: make(chan interface{}),
|
|
}
|
|
|
|
return sshTunnel
|
|
}
|
|
|
|
func PrivateKeyFile(file string) ssh.AuthMethod {
|
|
buffer, err := ioutil.ReadFile(file)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
key, err := ssh.ParsePrivateKey(buffer)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
return ssh.PublicKeys(key)
|
|
}
|
|
|
|
func SSHAgent() ssh.AuthMethod {
|
|
if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
|
|
return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
|
|
}
|
|
return nil
|
|
}
|