package breacher import ( "fmt" "io" "io/ioutil" "log" "net" "os" "strconv" "strings" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) var ( sPasswd string sKey string ) func init() { sshCmd.Flags().StringVarP(&sPasswd, "password", "p", "", "login using password") sshCmd.Flags().StringVarP(&sKey, "keyfile", "i", "", "login keyfile (path/to/file)") } var sshCmd = &cobra.Command{ Use: "tunnel [from address] [to address] [user@host:port]", Short: "ssh tunneling to access remote services", Long: `ssh tunneling to access remote services ex. breacher tunnel :8080 host:80 user@example.com -p paswd breacher tunnel :8080 :80 user@example.com -i ~/.ssh/id_rsa breacher tunnel :8080 kumoly.io:443 user@example.com `, Args: cobra.ExactArgs(3), Run: func(cmd *cobra.Command, args []string) { localHost, localPortStr, err := net.SplitHostPort(args[0]) if err != nil { log.Fatalln(err) } remoteHost, remotePortStr, err := net.SplitHostPort(args[1]) if err != nil { log.Fatalln(err) } localPort, err := strconv.Atoi(localPortStr) if err != nil { log.Fatalln(err) } remotePort, err := strconv.Atoi(remotePortStr) if err != nil { log.Fatalln(err) } if localHost == "" { localHost = "0.0.0.0" } if remoteHost == "" { remoteHost = "localhost" } split := strings.Split(args[2], "@") if len(split) != 2 { log.Fatalln("ssh host name not valid") } usr := split[0] sshHost := "localhost" sshPort := 22 if !strings.Contains(split[1], ":") { sshHost = split[1] } else { sshPortStr := "" sshHost, sshPortStr, err = net.SplitHostPort(split[1]) if err != nil { log.Fatalln(err) } if sshHost == "" { log.Fatalln("no ssh host") } if sshPortStr != "" { sshPort, err = strconv.Atoi(sshPortStr) if err != nil { log.Fatalln(err) } } } var auth ssh.AuthMethod if sPasswd != "" { auth = ssh.Password(sPasswd) } else if sKey != "" { auth = PrivateKeyFile(sKey) } else { auth = SSHAgent() } st := NewSSHTunnel( &Endpoint{localHost, localPort, ""}, &Endpoint{remoteHost, remotePort, ""}, &Endpoint{sshHost, sshPort, usr}, auth, ) log.Fatalln(st.Start()) }, } type Endpoint struct { Host string Port int User string } func (endpoint *Endpoint) String() string { return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) } type SSHTunnel struct { Local *Endpoint Server *Endpoint Remote *Endpoint Config *ssh.ClientConfig Conns []net.Conn SvrConns []*ssh.Client isOpen bool close chan interface{} } func newConnectionWaiter(listener net.Listener, c chan net.Conn) { conn, err := listener.Accept() if err != nil { fmt.Println(err) return } c <- conn } func (tunnel *SSHTunnel) Start() error { listener, err := net.Listen("tcp", tunnel.Local.String()) if err != nil { return err } tunnel.isOpen = true tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port for { if !tunnel.isOpen { break } c := make(chan net.Conn) go newConnectionWaiter(listener, c) log.Println("listening for new connections...") select { case <-tunnel.close: log.Println("close signal received, closing...") tunnel.isOpen = false case conn := <-c: tunnel.Conns = append(tunnel.Conns, conn) log.Println("accepted connection") go tunnel.forward(conn) } } var total int total = len(tunnel.Conns) for i, conn := range tunnel.Conns { log.Printf("closing the netConn (%d of %d)\n", i+1, total) err := conn.Close() if err != nil { log.Println(err.Error()) } } total = len(tunnel.SvrConns) for i, conn := range tunnel.SvrConns { log.Printf("closing the serverConn (%d of %d)\n", i+1, total) err := conn.Close() if err != nil { log.Println(err.Error()) } } err = listener.Close() if err != nil { return err } log.Println("tunnel closed") return nil } func (tunnel *SSHTunnel) forward(localConn net.Conn) { serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) if err != nil { log.Printf("server dial error: %s\n", err) return } log.Printf("connected to %s (1 of 2)\n", tunnel.Server.String()) tunnel.SvrConns = append(tunnel.SvrConns, serverConn) remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String()) if err != nil { log.Printf("remote dial error: %s\n", err) return } tunnel.Conns = append(tunnel.Conns, remoteConn) log.Printf("connected to %s (2 of 2)\n", tunnel.Remote.String()) copyConn := func(writer, reader net.Conn) { _, err := io.Copy(writer, reader) if err != nil { log.Printf("io.Copy error: %s\n", err) } } go copyConn(localConn, remoteConn) go copyConn(remoteConn, localConn) } func (tunnel *SSHTunnel) Close() { tunnel.close <- struct{}{} } // NewSSHTunnel creates a new single-use tunnel. Supplying "0" for localport will use a random port. func NewSSHTunnel(from, to, server *Endpoint, auth ssh.AuthMethod) *SSHTunnel { if server.Port == 0 { server.Port = 22 } sshTunnel := &SSHTunnel{ Config: &ssh.ClientConfig{ User: server.User, Auth: []ssh.AuthMethod{auth}, HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { // Always accept key. return nil }, }, Local: from, Server: server, Remote: to, close: make(chan interface{}), } return sshTunnel } func PrivateKeyFile(file string) ssh.AuthMethod { buffer, err := ioutil.ReadFile(file) if err != nil { return nil } key, err := ssh.ParsePrivateKey(buffer) if err != nil { return nil } return ssh.PublicKeys(key) } func SSHAgent() ssh.AuthMethod { if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) } return nil }