500 lines
14 KiB
Go
500 lines
14 KiB
Go
package breacher
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"io/ioutil"
|
|
"os/user"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
)
|
|
|
|
// Endpoint identifies one end of a connection: either a TCP host/port pair
// or, when UnixSocket is non-empty, the path of a Unix domain socket.
type Endpoint struct {
	Host       string // TCP host; only used when UnixSocket is empty
	Port       int    // TCP port; only used when UnixSocket is empty
	UnixSocket string // Unix socket path; takes precedence over Host and Port
}

// connectionString returns the address to listen on or dial: the Unix socket
// path when one is set, otherwise "host:port".
func (e *Endpoint) connectionString() string {
	if e.UnixSocket == "" {
		return fmt.Sprintf("%s:%d", e.Host, e.Port)
	}
	return e.UnixSocket
}

// connectionType returns the network name matching connectionString:
// "unix" when a Unix socket path is set, otherwise "tcp".
func (e *Endpoint) connectionType() string {
	switch e.UnixSocket {
	case "":
		return "tcp"
	default:
		return "unix"
	}
}
|
|
|
|
// AuthType is the type of authentication to use for SSH.
// It is selected through the Set* methods of SSHTun and resolved into an
// ssh.AuthMethod by getSSHAuthMethod when the tunnel starts.
type AuthType int

const (
	// AuthTypeKeyFile uses the keys from a SSH key file read from the system.
	AuthTypeKeyFile AuthType = iota
	// AuthTypeEncryptedKeyFile uses the keys from an encrypted SSH key file read from the system.
	// The stored password is used as the passphrase.
	AuthTypeEncryptedKeyFile
	// AuthTypeKeyReader uses the keys from a SSH key reader.
	AuthTypeKeyReader
	// AuthTypeEncryptedKeyReader uses the keys from an encrypted SSH key reader.
	// The stored password is used as the passphrase.
	AuthTypeEncryptedKeyReader
	// AuthTypePassword uses a password directly.
	AuthTypePassword
	// AuthTypeSSHAgent will use the keys registered in the ssh-agent.
	AuthTypeSSHAgent
	// AuthTypeAuto tries to get the authentication method automatically: the key file is
	// tried first and the ssh-agent is used when that fails. See SSHTun.Start for details.
	AuthTypeAuto
)
|
|
|
|
// SSHTun represents a SSH tunnel: it listens on a local endpoint and forwards
// every accepted connection to a remote endpoint through an SSH server.
type SSHTun struct {
	// Embedded mutex held by Start while it sets the tunnel up and by
	// errStarted while it tears it down.
	*sync.Mutex
	// ctx and cancel are created by Start; cancelling closes the local
	// listener and ends the running forwards.
	ctx    context.Context
	cancel context.CancelFunc
	// errCh carries the error (or nil) that makes a blocked Start return.
	errCh chan error
	// user is the SSH user name.
	user string
	// authType selects how the SSH auth method is built.
	authType AuthType
	// authKeyFile is the private key path for the key-file auth types.
	authKeyFile string
	// authKeyReader is the private key source for the key-reader auth types.
	authKeyReader io.Reader
	// authPassword is the SSH password, or the passphrase of an encrypted key.
	authPassword string
	// server is the SSH server to connect to.
	server Endpoint
	// local is the endpoint the tunnel listens on.
	local Endpoint
	// remote is the endpoint dialed from the SSH server.
	remote Endpoint
	// started is true between a successful Start and the tunnel stopping.
	started bool
	// timeout is used as the SSH client connection timeout.
	timeout time.Duration
	// debug enables log.Printf messages.
	debug bool
	// connState is the optional state-change callback.
	connState func(*SSHTun, ConnState)
}
|
|
|
|
// ConnState represents the state of the SSH tunnel. It is passed to the optional callback registered with SetConnState.
type ConnState int

const (
	// StateStopped represents a stopped tunnel. A call to Start will make the state transition to StateStarting.
	StateStopped ConnState = iota

	// StateStarting represents a tunnel initializing and preparing to listen for connections.
	// A successful initialization will make the state transition to StateStarted, otherwise it will transition to StateStopped.
	StateStarting

	// StateStarted represents a tunnel ready to accept connections.
	// A call to Stop or an error will make the state transition to StateStopped.
	StateStarted
)
|
|
|
|
// New creates a new SSH tunnel to the specified server redirecting a port on local localhost to a port on remote localhost.
|
|
// By default the SSH connection is made to port 22 as root and using automatic detection of the authentication
|
|
// method (see Start for details on this).
|
|
// Calling SetPassword will change the authentication to password based.
|
|
// Calling SetKeyFile will change the authentication to keyfile based with an optional key file.
|
|
// The SSH user and port can be changed with SetUser and SetPort.
|
|
// The local and remote hosts can be changed to something different than localhost with SetLocalHost and SetRemoteHost.
|
|
// The states of the tunnel can be received throgh a callback function with SetConnState.
|
|
func NewSSHTunnel(localHost string, localPort int, remoteHost string, remotePort int) *SSHTun {
|
|
return &SSHTun{
|
|
Mutex: &sync.Mutex{},
|
|
server: Endpoint{
|
|
Host: "",
|
|
Port: 22,
|
|
},
|
|
user: "root",
|
|
authType: AuthTypeAuto,
|
|
authKeyFile: "",
|
|
authPassword: "",
|
|
local: Endpoint{
|
|
Host: localHost,
|
|
Port: localPort,
|
|
},
|
|
remote: Endpoint{
|
|
Host: remoteHost,
|
|
Port: remotePort,
|
|
},
|
|
started: false,
|
|
timeout: time.Second * 15,
|
|
debug: false,
|
|
}
|
|
}
|
|
|
|
func NewUnix(localUnixSocket string, server string, remoteUnixSocket string) *SSHTun {
|
|
return &SSHTun{
|
|
Mutex: &sync.Mutex{},
|
|
server: Endpoint{
|
|
Host: server,
|
|
Port: 22,
|
|
},
|
|
user: "root",
|
|
authType: AuthTypeAuto,
|
|
authKeyFile: "",
|
|
authPassword: "",
|
|
local: Endpoint{
|
|
UnixSocket: localUnixSocket,
|
|
},
|
|
remote: Endpoint{
|
|
UnixSocket: remoteUnixSocket,
|
|
},
|
|
started: false,
|
|
timeout: time.Second * 15,
|
|
debug: false,
|
|
}
|
|
}
|
|
|
|
// SetPort changes the port where the SSH connection will be made
// (both constructors in this file default it to 22).
func (tun *SSHTun) SetPort(port int) {
	tun.server.Port = port
}
|
|
|
|
// SetUser changes the user used to make the SSH connection
// (both constructors in this file default it to "root").
func (tun *SSHTun) SetUser(user string) {
	tun.user = user
}
|
|
|
|
// SetKeyFile changes the authentication to key-based and uses the specified file.
|
|
// Leaving it empty defaults to the default linux private key location ($HOME/.ssh/id_rsa).
|
|
func (tun *SSHTun) SetKeyFile(file string) {
|
|
tun.authType = AuthTypeKeyFile
|
|
tun.authKeyFile = file
|
|
}
|
|
|
|
// SetEncryptedKeyFile changes the authentication to encrypted key-based and uses the specified file and password.
|
|
// Leaving it empty defaults to the default linux private key location ($HOME/.ssh/id_rsa).
|
|
func (tun *SSHTun) SetEncryptedKeyFile(file string, password string) {
|
|
tun.authType = AuthTypeEncryptedKeyFile
|
|
tun.authKeyFile = file
|
|
tun.authPassword = password
|
|
}
|
|
|
|
// SetKeyReader changes the authentication to key-based and uses the specified reader.
|
|
// Leaving it empty defaults to the default linux private key location ($HOME/.ssh/id_rsa).
|
|
func (tun *SSHTun) SetKeyReader(reader io.Reader) {
|
|
tun.authType = AuthTypeKeyReader
|
|
tun.authKeyReader = reader
|
|
}
|
|
|
|
// SetEncryptedKeyReader changes the authentication to encrypted key-based and uses the specified reader and password.
|
|
// Leaving it empty defaults to the default linux private key location ($HOME/.ssh/id_rsa).
|
|
func (tun *SSHTun) SetEncryptedKeyReader(reader io.Reader, password string) {
|
|
tun.authType = AuthTypeEncryptedKeyReader
|
|
tun.authKeyReader = reader
|
|
tun.authPassword = password
|
|
}
|
|
|
|
// SetSSHAgent changes the authentication to ssh-agent.
// The agent is reached through the Unix socket named by the SSH_AUTH_SOCK
// environment variable when the SSH configuration is built.
func (tun *SSHTun) SetSSHAgent() {
	tun.authType = AuthTypeSSHAgent
}
|
|
|
|
// SetPassword changes the authentication to password-based and uses the specified password.
|
|
func (tun *SSHTun) SetPassword(password string) {
|
|
tun.authType = AuthTypePassword
|
|
tun.authPassword = password
|
|
}
|
|
|
|
// SetLocalHost sets the host of the local endpoint the tunnel listens on.
// It is only used when the local endpoint has no Unix socket set.
func (tun *SSHTun) SetLocalHost(host string) {
	tun.local.Host = host
}
|
|
|
|
// SetRemoteHost sets the host of the remote endpoint dialed from the SSH server.
// It is only used when the remote endpoint has no Unix socket set.
func (tun *SSHTun) SetRemoteHost(host string) {
	tun.remote.Host = host
}
|
|
|
|
// SetTimeout sets the timeout passed to the SSH client configuration
// (defaults to 15 seconds).
func (tun *SSHTun) SetTimeout(timeout time.Duration) {
	tun.timeout = timeout
}
|
|
|
|
// SetDebug enables or disables log messages (disabled by default).
// When enabled, the tunnel reports listening, accepted connections and
// the opening and closing of forwards through log.Printf.
func (tun *SSHTun) SetDebug(debug bool) {
	tun.debug = debug
}
|
|
|
|
// SetConnState specifies an optional callback function that is called when a SSH tunnel changes state.
// See the ConnState type and associated constants for details.
// The callback is invoked while the tunnel's mutex is held, so it must not
// call Start or Stop on the same tunnel.
func (tun *SSHTun) SetConnState(connStateFun func(*SSHTun, ConnState)) {
	tun.connState = connStateFun
}
|
|
|
|
// Start starts the SSH tunnel. After this call, all Set* methods will have no effect until Close is called.
|
|
// Note on SSH authentication: in case the tunnel's authType is set to AuthTypeAuto the following will happen:
|
|
// The default key file will be used, if that doesn't succeed it will try to use the SSH agent.
|
|
// If that fails the whole authentication fails.
|
|
// That means if you want to use password or encrypted key file authentication, you have to specify that explicitly.
|
|
func (tun *SSHTun) Start() error {
|
|
tun.Lock()
|
|
|
|
if tun.connState != nil {
|
|
tun.connState(tun, StateStarting)
|
|
}
|
|
|
|
// SSH config
|
|
config, err := tun.initSSHConfig()
|
|
if err != nil {
|
|
return tun.errNotStarted(err)
|
|
}
|
|
|
|
local := tun.local.connectionString()
|
|
// Local listener
|
|
localList, err := net.Listen(tun.local.connectionType(), local)
|
|
if err != nil {
|
|
return tun.errNotStarted(fmt.Errorf("local listen on %s failed: %s", local, err.Error()))
|
|
}
|
|
|
|
// Context and error channel
|
|
tun.ctx, tun.cancel = context.WithCancel(context.Background())
|
|
tun.errCh = make(chan error)
|
|
|
|
// Accept connections
|
|
go func() {
|
|
for {
|
|
localConn, err := localList.Accept()
|
|
if err != nil {
|
|
tun.errStarted(fmt.Errorf("local accept on %s failed: %s", local, err.Error()))
|
|
break
|
|
}
|
|
if tun.debug {
|
|
log.Printf("Accepted connection from %s", localConn.RemoteAddr().String())
|
|
}
|
|
|
|
// Launch the forward
|
|
go tun.forward(localConn, config)
|
|
}
|
|
}()
|
|
|
|
// Wait until someone cancels the context and stop accepting connections
|
|
go func() {
|
|
<-tun.ctx.Done()
|
|
localList.Close()
|
|
}()
|
|
|
|
// Now others can call Stop or fail
|
|
if tun.debug {
|
|
log.Printf("Listening on %s", local)
|
|
}
|
|
tun.started = true
|
|
if tun.connState != nil {
|
|
tun.connState(tun, StateStarted)
|
|
}
|
|
tun.Unlock()
|
|
|
|
// Wait to exit
|
|
errFromCh := <-tun.errCh
|
|
return errFromCh
|
|
|
|
}
|
|
|
|
func (tun *SSHTun) errNotStarted(err error) error {
|
|
tun.started = false
|
|
if tun.connState != nil {
|
|
tun.connState(tun, StateStopped)
|
|
}
|
|
tun.Unlock()
|
|
return err
|
|
}
|
|
|
|
func (tun *SSHTun) errStarted(err error) {
|
|
tun.Lock()
|
|
if tun.started {
|
|
tun.cancel()
|
|
if tun.connState != nil {
|
|
tun.connState(tun, StateStopped)
|
|
}
|
|
tun.started = false
|
|
tun.errCh <- err
|
|
}
|
|
tun.Unlock()
|
|
}
|
|
|
|
func (tun *SSHTun) initSSHConfig() (*ssh.ClientConfig, error) {
|
|
config := &ssh.ClientConfig{
|
|
User: tun.user,
|
|
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
return nil
|
|
},
|
|
Timeout: tun.timeout,
|
|
}
|
|
|
|
authMethod, err := tun.getSSHAuthMethod()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
config.Auth = []ssh.AuthMethod{authMethod}
|
|
|
|
return config, nil
|
|
}
|
|
|
|
func (tun *SSHTun) getSSHAuthMethod() (ssh.AuthMethod, error) {
|
|
switch tun.authType {
|
|
case AuthTypeKeyFile:
|
|
return tun.getSSHAuthMethodForKeyFile(false)
|
|
case AuthTypeEncryptedKeyFile:
|
|
return tun.getSSHAuthMethodForKeyFile(true)
|
|
case AuthTypeKeyReader:
|
|
return tun.getSSHAuthMethodForKeyReader(false)
|
|
case AuthTypeEncryptedKeyReader:
|
|
return tun.getSSHAuthMethodForKeyReader(true)
|
|
case AuthTypePassword:
|
|
return ssh.Password(tun.authPassword), nil
|
|
case AuthTypeSSHAgent:
|
|
return tun.getSSHAuthMethodForSSHAgent()
|
|
case AuthTypeAuto:
|
|
method, err := tun.getSSHAuthMethodForKeyFile(false)
|
|
if err != nil {
|
|
return tun.getSSHAuthMethodForSSHAgent()
|
|
}
|
|
return method, nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown auth type: %d", tun.authType)
|
|
}
|
|
}
|
|
|
|
func (tun *SSHTun) getSSHAuthMethodForKeyFile(encrypted bool) (ssh.AuthMethod, error) {
|
|
if tun.authKeyFile == "" {
|
|
usr, _ := user.Current()
|
|
if usr != nil {
|
|
tun.authKeyFile = usr.HomeDir + "/.ssh/id_rsa"
|
|
} else {
|
|
tun.authKeyFile = "/root/.ssh/id_rsa"
|
|
}
|
|
}
|
|
buf, err := ioutil.ReadFile(tun.authKeyFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading SSH key file %s: %s", tun.authKeyFile, err.Error())
|
|
}
|
|
key, err := tun.parsePrivateKey(buf, encrypted)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading SSH key file %s: %s", tun.authKeyFile, err.Error())
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (tun *SSHTun) getSSHAuthMethodForKeyReader(encrypted bool) (ssh.AuthMethod, error) {
|
|
buf, err := ioutil.ReadAll(tun.authKeyReader)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading from SSH key reader: %s", err.Error())
|
|
}
|
|
key, err := tun.parsePrivateKey(buf, encrypted)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading from SSH key reader: %s", err.Error())
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (tun *SSHTun) parsePrivateKey(buf []byte, encrypted bool) (ssh.AuthMethod, error) {
|
|
var key ssh.Signer
|
|
var err error
|
|
if encrypted {
|
|
key, err = ssh.ParsePrivateKeyWithPassphrase(buf, []byte(tun.authPassword))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error parsing encrypted key: %s", err.Error())
|
|
}
|
|
} else {
|
|
key, err = ssh.ParsePrivateKey(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error parsing key: %s", err.Error())
|
|
}
|
|
}
|
|
return ssh.PublicKeys(key), nil
|
|
}
|
|
|
|
func (tun *SSHTun) getSSHAuthMethodForSSHAgent() (ssh.AuthMethod, error) {
|
|
conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error opening unix socket: %s", err)
|
|
}
|
|
|
|
agentClient := agent.NewClient(conn)
|
|
|
|
signers, err := agentClient.Signers()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting ssh-agent signers: %s", err)
|
|
}
|
|
|
|
if len(signers) == 0 {
|
|
return nil, fmt.Errorf("no signers from ssh-agent. Use 'ssh-add' to add keys to agent")
|
|
}
|
|
|
|
return ssh.PublicKeys(signers...), nil
|
|
}
|
|
|
|
func (tun *SSHTun) forward(localConn net.Conn, config *ssh.ClientConfig) {
|
|
defer localConn.Close()
|
|
|
|
local := tun.local.connectionString()
|
|
server := tun.server.connectionString()
|
|
remote := tun.remote.connectionString()
|
|
|
|
sshConn, err := ssh.Dial(tun.server.connectionType(), server, config)
|
|
if err != nil {
|
|
tun.errStarted(fmt.Errorf("SSH connection to %s failed: %s", server, err.Error()))
|
|
return
|
|
}
|
|
defer sshConn.Close()
|
|
if tun.debug {
|
|
log.Printf("SSH connection to %s done", server)
|
|
}
|
|
|
|
remoteConn, err := sshConn.Dial(tun.remote.connectionType(), remote)
|
|
if err != nil {
|
|
if tun.debug {
|
|
log.Printf("Remote dial to %s failed: %s", remote, err.Error())
|
|
}
|
|
return
|
|
}
|
|
defer remoteConn.Close()
|
|
if tun.debug {
|
|
log.Printf("Remote connection to %s done", remote)
|
|
}
|
|
|
|
connStr := fmt.Sprintf("%s -(tcp)> %s -(ssh)> %s -(tcp)> %s", localConn.RemoteAddr().String(), local, server, remote)
|
|
if tun.debug {
|
|
log.Printf("SSH tunnel OPEN: %s", connStr)
|
|
}
|
|
|
|
myCtx, myCancel := context.WithCancel(tun.ctx)
|
|
|
|
go func() {
|
|
_, err = io.Copy(remoteConn, localConn)
|
|
if err != nil {
|
|
//log.Printf("Error on io.Copy remote->local on connection %s: %s", connStr, err.Error())
|
|
myCancel()
|
|
return
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
_, err = io.Copy(localConn, remoteConn)
|
|
if err != nil {
|
|
//log.Printf("Error on io.Copy local->remote on connection %s: %s", connStr, err.Error())
|
|
myCancel()
|
|
return
|
|
}
|
|
}()
|
|
|
|
<-myCtx.Done()
|
|
myCancel()
|
|
if tun.debug {
|
|
log.Printf("SSH tunnel CLOSE: %s", connStr)
|
|
}
|
|
|
|
}
|
|
|
|
// Stop closes the SSH tunnel and its connections.
// It makes a blocked Start return nil and is a no-op when the tunnel is not started.
// After this call Start can be called again.
func (tun *SSHTun) Stop() {
	tun.errStarted(nil)
}
|