breach/breacher/ssh_lib.go

500 lines
14 KiB
Go

package breacher
import (
"context"
"fmt"
"io"
"log"
"net"
"os"
"sync"
"time"
"io/ioutil"
"os/user"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// Endpoint describes one end of a tunnel leg: either a TCP host/port pair or,
// when UnixSocket is non-empty, a unix domain socket path.
type Endpoint struct {
	// Host is the TCP host name or IP address. Ignored when UnixSocket is set.
	Host string
	// Port is the TCP port. Ignored when UnixSocket is set.
	Port int
	// UnixSocket is a socket path; a non-empty value takes precedence over Host and Port.
	UnixSocket string
}

// connectionString returns the address to hand to Dial/Listen: the socket
// path for unix endpoints, "host:port" otherwise.
//
// The TCP form is built with net.JoinHostPort so that a bare IPv6 literal
// such as "::1" becomes "[::1]:22" instead of the undialable "::1:22".
// A host that is already wrapped in brackets is unwrapped first so it is not
// bracketed twice.
func (e *Endpoint) connectionString() string {
	if e.UnixSocket != "" {
		return e.UnixSocket
	}
	host := e.Host
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	return net.JoinHostPort(host, fmt.Sprint(e.Port))
}

// connectionType returns the network name matching connectionString:
// "unix" when UnixSocket is set, "tcp" otherwise.
func (e *Endpoint) connectionType() string {
	if e.UnixSocket != "" {
		return "unix"
	}
	return "tcp"
}
// AuthType selects how the tunnel authenticates against the SSH server.
type AuthType int

// Supported authentication modes. The numeric values follow declaration order, starting at zero.
const (
	// AuthTypeKeyFile authenticates with an unencrypted private key read from a file.
	AuthTypeKeyFile AuthType = iota

	// AuthTypeEncryptedKeyFile authenticates with a passphrase-protected private key read from a file.
	AuthTypeEncryptedKeyFile

	// AuthTypeKeyReader authenticates with an unencrypted private key read from an io.Reader.
	AuthTypeKeyReader

	// AuthTypeEncryptedKeyReader authenticates with a passphrase-protected private key read from an io.Reader.
	AuthTypeEncryptedKeyReader

	// AuthTypePassword authenticates with a plain password.
	AuthTypePassword

	// AuthTypeSSHAgent authenticates with the keys registered in the running ssh-agent.
	AuthTypeSSHAgent

	// AuthTypeAuto picks the method automatically; see SSHTun.Start for the
	// exact fallback order.
	AuthTypeAuto
)
// SSHTun is an SSH tunnel: it listens on a local endpoint and forwards every
// accepted connection through an SSH server to a remote endpoint.
type SSHTun struct {
	// Embedded mutex; Start, errNotStarted and errStarted use it to guard the
	// lifecycle fields below.
	*sync.Mutex

	// Lifecycle of a running tunnel; (re)created by every Start.
	ctx    context.Context
	cancel context.CancelFunc
	errCh  chan error // carries the value Start eventually returns

	// SSH login settings.
	user          string
	authType      AuthType
	authKeyFile   string
	authKeyReader io.Reader
	authPassword  string // password, or key passphrase for the encrypted auth types

	// The three addresses involved: SSH server, local listener, remote target.
	server Endpoint
	local  Endpoint
	remote Endpoint

	started bool          // true between a successful Start and the next stop/error
	timeout time.Duration // used as the SSH client connection timeout
	debug   bool          // enables log.Printf tracing

	// Optional observer for state transitions; may be nil.
	connState func(*SSHTun, ConnState)
}
// ConnState is the lifecycle state of a tunnel. Values are delivered to the
// callback registered with SetConnState.
type ConnState int

// Tunnel states, in the order a successful run passes through them.
const (
	// StateStopped means the tunnel is not running. Calling Start moves it
	// to StateStarting.
	StateStopped ConnState = iota

	// StateStarting means the tunnel is initializing and getting ready to
	// listen. It moves to StateStarted on success and back to StateStopped
	// on failure.
	StateStarting

	// StateStarted means the tunnel is accepting connections. A stop request
	// or an error moves it to StateStopped.
	StateStarted
)
// NewSSHTunnel creates a TCP tunnel that listens on localHost:localPort and
// forwards to remoteHost:remotePort as seen from the SSH server.
//
// Defaults: SSH port 22, user "root", a 15 second timeout, debug logging off
// and automatic detection of the authentication method (see Start).
// The SSH server host is left empty by this constructor.
// NOTE(review): no setter for the server host is visible in this part of the
// file — confirm how callers are expected to provide it.
//
// SetPassword switches to password authentication and SetKeyFile to key file
// authentication. The SSH user and port can be changed with SetUser and
// SetPort, the local and remote hosts with SetLocalHost and SetRemoteHost.
// State transitions can be observed through a callback set with SetConnState.
func NewSSHTunnel(localHost string, localPort int, remoteHost string, remotePort int) *SSHTun {
	tun := &SSHTun{
		Mutex:    &sync.Mutex{},
		server:   Endpoint{Port: 22},
		local:    Endpoint{Host: localHost, Port: localPort},
		remote:   Endpoint{Host: remoteHost, Port: remotePort},
		user:     "root",
		authType: AuthTypeAuto,
		timeout:  15 * time.Second,
	}
	return tun
}
// NewUnix creates a tunnel between two unix domain sockets: it listens on
// localUnixSocket and forwards to remoteUnixSocket through the SSH server at
// the given host.
//
// Defaults match NewSSHTunnel: SSH port 22, user "root", a 15 second timeout,
// debug logging off and automatic detection of the authentication method.
func NewUnix(localUnixSocket string, server string, remoteUnixSocket string) *SSHTun {
	tun := &SSHTun{
		Mutex:    &sync.Mutex{},
		server:   Endpoint{Host: server, Port: 22},
		local:    Endpoint{UnixSocket: localUnixSocket},
		remote:   Endpoint{UnixSocket: remoteUnixSocket},
		user:     "root",
		authType: AuthTypeAuto,
		timeout:  15 * time.Second,
	}
	return tun
}
// SetPort changes the TCP port of the SSH server the tunnel connects to
// (the constructors set it to 22).
func (tun *SSHTun) SetPort(port int) {
tun.server.Port = port
}
// SetUser changes the user name used to log in to the SSH server
// (the constructors set it to "root").
func (tun *SSHTun) SetUser(user string) {
tun.user = user
}
// SetKeyFile switches authentication to an unencrypted private key file.
// An empty file name selects the default location: .ssh/id_rsa under the
// current user's home directory, or /root/.ssh/id_rsa when the current user
// cannot be determined.
func (tun *SSHTun) SetKeyFile(file string) {
tun.authType = AuthTypeKeyFile
tun.authKeyFile = file
}
// SetEncryptedKeyFile switches authentication to a passphrase-protected
// private key file; password is the key's passphrase.
// An empty file name selects the default location: .ssh/id_rsa under the
// current user's home directory, or /root/.ssh/id_rsa when the current user
// cannot be determined.
func (tun *SSHTun) SetEncryptedKeyFile(file string, password string) {
tun.authType = AuthTypeEncryptedKeyFile
tun.authKeyFile = file
tun.authPassword = password
}
// SetKeyReader switches authentication to an unencrypted private key that is
// read in full from reader when the SSH configuration is built.
// Unlike the key file variants there is no default: reader must be non-nil.
func (tun *SSHTun) SetKeyReader(reader io.Reader) {
tun.authType = AuthTypeKeyReader
tun.authKeyReader = reader
}
// SetEncryptedKeyReader switches authentication to a passphrase-protected
// private key that is read in full from reader when the SSH configuration is
// built; password is the key's passphrase.
// Unlike the key file variants there is no default: reader must be non-nil.
func (tun *SSHTun) SetEncryptedKeyReader(reader io.Reader, password string) {
tun.authType = AuthTypeEncryptedKeyReader
tun.authKeyReader = reader
tun.authPassword = password
}
// SetSSHAgent switches authentication to the keys held by the ssh-agent
// reachable through the SSH_AUTH_SOCK environment variable.
func (tun *SSHTun) SetSSHAgent() {
tun.authType = AuthTypeSSHAgent
}
// SetPassword switches authentication to password-based and stores the
// password to use.
func (tun *SSHTun) SetPassword(password string) {
tun.authType = AuthTypePassword
tun.authPassword = password
}
// SetLocalHost replaces the host part of the local listening endpoint.
// It has no effect on the address used while the local endpoint has a unix
// socket path, because the socket path takes precedence.
func (tun *SSHTun) SetLocalHost(host string) {
tun.local.Host = host
}
// SetRemoteHost replaces the host part of the remote endpoint that
// connections are forwarded to.
// It has no effect on the address used while the remote endpoint has a unix
// socket path, because the socket path takes precedence.
func (tun *SSHTun) SetRemoteHost(host string) {
tun.remote.Host = host
}
// SetTimeout sets the timeout placed in the SSH client configuration
// (the constructors set it to 15 seconds).
func (tun *SSHTun) SetTimeout(timeout time.Duration) {
tun.timeout = timeout
}
// SetDebug enables or disables tracing through the standard log package
// (disabled by the constructors).
func (tun *SSHTun) SetDebug(debug bool) {
tun.debug = debug
}
// SetConnState registers an optional callback that is invoked whenever the
// tunnel changes state. See the ConnState type and its constants for the
// possible values.
// The callback is invoked while the tunnel's mutex is held, so it must not
// call methods that take that mutex (for example Start or Stop).
func (tun *SSHTun) SetConnState(connStateFun func(*SSHTun, ConnState)) {
tun.connState = connStateFun
}
// Start starts the SSH tunnel and blocks until it stops. It returns nil when
// the tunnel is ended with Stop, the error that ended it otherwise, or an
// error right away if the tunnel could not be set up or is already running.
//
// The SSH configuration is built once at the beginning of Start, so changes
// to the user, timeout or authentication settings made while the tunnel runs
// only take effect on the next Start.
//
// Note on SSH authentication: when the authType is AuthTypeAuto the default
// key file is tried first and, if that fails, the SSH agent. If that fails
// too, authentication as a whole fails. Password and encrypted key
// authentication therefore have to be selected explicitly.
func (tun *SSHTun) Start() error {
	tun.Lock()

	// A second Start on a running tunnel would either fail to listen and mark
	// the live tunnel as stopped, or replace its context and error channel.
	// Refuse it without touching any state.
	if tun.started {
		tun.Unlock()
		return fmt.Errorf("tunnel already started")
	}

	if tun.connState != nil {
		tun.connState(tun, StateStarting)
	}

	// SSH config. errNotStarted releases the lock on the failure paths.
	config, err := tun.initSSHConfig()
	if err != nil {
		return tun.errNotStarted(err)
	}

	local := tun.local.connectionString()

	// Local listener
	localList, err := net.Listen(tun.local.connectionType(), local)
	if err != nil {
		return tun.errNotStarted(fmt.Errorf("local listen on %s failed: %s", local, err.Error()))
	}

	// Context and error channel. Local copies are used by this call and its
	// goroutines so they never re-read the fields without holding the lock.
	ctx, cancel := context.WithCancel(context.Background())
	errCh := make(chan error)
	tun.ctx, tun.cancel = ctx, cancel
	tun.errCh = errCh

	// Accept connections until the listener fails or is closed.
	go func() {
		for {
			localConn, err := localList.Accept()
			if err != nil {
				tun.errStarted(fmt.Errorf("local accept on %s failed: %s", local, err.Error()))
				break
			}
			if tun.debug {
				log.Printf("Accepted connection from %s", localConn.RemoteAddr().String())
			}
			// Launch the forward
			go tun.forward(localConn, config)
		}
	}()

	// Wait until someone cancels the context and stop accepting connections.
	go func() {
		<-ctx.Done()
		localList.Close()
	}()

	// Now others can call Stop or fail
	if tun.debug {
		log.Printf("Listening on %s", local)
	}
	tun.started = true
	if tun.connState != nil {
		tun.connState(tun, StateStarted)
	}
	tun.Unlock()

	// Wait to exit: errStarted sends exactly one value per started tunnel.
	return <-errCh
}
// errNotStarted is the failure exit of Start: it marks the tunnel as not
// started, reports StateStopped to the state callback (if any), releases the
// mutex and returns err unchanged.
// It must be called with the mutex held; it is the one that unlocks it.
func (tun *SSHTun) errNotStarted(err error) error {
tun.started = false
if tun.connState != nil {
tun.connState(tun, StateStopped)
}
tun.Unlock()
return err
}
// errStarted shuts down a running tunnel: it cancels the tunnel context,
// reports StateStopped to the state callback (if any), clears the started
// flag and hands err to the blocked Start call through errCh.
// It does nothing when the tunnel is not started, so only the first caller
// after a Start has any effect. A nil err is what Stop passes.
func (tun *SSHTun) errStarted(err error) {
	tun.Lock()

	if !tun.started {
		tun.Unlock()
		return
	}

	tun.cancel()
	if notify := tun.connState; notify != nil {
		notify(tun, StateStopped)
	}
	tun.started = false
	// Unbuffered send: returns once Start has picked the value up.
	tun.errCh <- err

	tun.Unlock()
}
// initSSHConfig builds the SSH client configuration from the tunnel's user,
// timeout and authentication settings. It returns an error when no
// authentication method can be obtained.
//
// NOTE(review): the server host key is accepted without any verification,
// which leaves the connection open to man-in-the-middle attacks. Confirm
// this is intended before using the tunnel across untrusted networks.
func (tun *SSHTun) initSSHConfig() (*ssh.ClientConfig, error) {
	auth, err := tun.getSSHAuthMethod()
	if err != nil {
		return nil, err
	}

	return &ssh.ClientConfig{
		User:            tun.user,
		Auth:            []ssh.AuthMethod{auth},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         tun.timeout,
	}, nil
}
// getSSHAuthMethod returns the ssh.AuthMethod matching the configured
// authType. For AuthTypeAuto the (unencrypted) key file is tried first and
// the ssh-agent is used as fallback; the key file error is discarded in that
// case. An unknown authType yields an error.
func (tun *SSHTun) getSSHAuthMethod() (ssh.AuthMethod, error) {
	switch kind := tun.authType; kind {
	case AuthTypeKeyFile, AuthTypeEncryptedKeyFile:
		return tun.getSSHAuthMethodForKeyFile(kind == AuthTypeEncryptedKeyFile)

	case AuthTypeKeyReader, AuthTypeEncryptedKeyReader:
		return tun.getSSHAuthMethodForKeyReader(kind == AuthTypeEncryptedKeyReader)

	case AuthTypePassword:
		return ssh.Password(tun.authPassword), nil

	case AuthTypeSSHAgent:
		return tun.getSSHAuthMethodForSSHAgent()

	case AuthTypeAuto:
		if method, err := tun.getSSHAuthMethodForKeyFile(false); err == nil {
			return method, nil
		}
		return tun.getSSHAuthMethodForSSHAgent()
	}

	return nil, fmt.Errorf("unknown auth type: %d", tun.authType)
}
// getSSHAuthMethodForKeyFile loads the private key from authKeyFile and
// returns a public key auth method for it. encrypted selects passphrase
// decryption with authPassword.
//
// When authKeyFile is empty it is first set (and stays set) to the default
// location: .ssh/id_rsa under the current user's home directory, or
// /root/.ssh/id_rsa when the current user cannot be determined.
func (tun *SSHTun) getSSHAuthMethodForKeyFile(encrypted bool) (ssh.AuthMethod, error) {
	if tun.authKeyFile == "" {
		home := "/root"
		// The lookup error is deliberately ignored: a nil user simply
		// selects the /root fallback.
		if usr, _ := user.Current(); usr != nil {
			home = usr.HomeDir
		}
		tun.authKeyFile = home + "/.ssh/id_rsa"
	}

	raw, readErr := ioutil.ReadFile(tun.authKeyFile)
	if readErr != nil {
		return nil, fmt.Errorf("error reading SSH key file %s: %s", tun.authKeyFile, readErr.Error())
	}

	method, parseErr := tun.parsePrivateKey(raw, encrypted)
	if parseErr != nil {
		return nil, fmt.Errorf("error reading SSH key file %s: %s", tun.authKeyFile, parseErr.Error())
	}
	return method, nil
}
// getSSHAuthMethodForKeyReader reads the whole private key from
// authKeyReader and returns a public key auth method for it. encrypted
// selects passphrase decryption with authPassword.
//
// It returns an error when no reader has been set, when reading fails or
// when the key cannot be parsed. The reader is drained by this call.
// NOTE(review): because of that, a later Start with the same reader will
// presumably see no data — confirm callers set a fresh reader per Start.
func (tun *SSHTun) getSSHAuthMethodForKeyReader(encrypted bool) (ssh.AuthMethod, error) {
	// ioutil.ReadAll would panic on a nil reader; report it as an error instead.
	if tun.authKeyReader == nil {
		return nil, fmt.Errorf("error reading from SSH key reader: no reader set")
	}
	buf, err := ioutil.ReadAll(tun.authKeyReader)
	if err != nil {
		return nil, fmt.Errorf("error reading from SSH key reader: %s", err.Error())
	}
	key, err := tun.parsePrivateKey(buf, encrypted)
	if err != nil {
		return nil, fmt.Errorf("error reading from SSH key reader: %s", err.Error())
	}
	return key, nil
}
// parsePrivateKey turns raw private key bytes into a public key auth method.
// When encrypted is true the key is decrypted with authPassword as its
// passphrase; otherwise it is parsed as is.
func (tun *SSHTun) parsePrivateKey(buf []byte, encrypted bool) (ssh.AuthMethod, error) {
	if !encrypted {
		signer, err := ssh.ParsePrivateKey(buf)
		if err != nil {
			return nil, fmt.Errorf("error parsing key: %s", err.Error())
		}
		return ssh.PublicKeys(signer), nil
	}

	signer, err := ssh.ParsePrivateKeyWithPassphrase(buf, []byte(tun.authPassword))
	if err != nil {
		return nil, fmt.Errorf("error parsing encrypted key: %s", err.Error())
	}
	return ssh.PublicKeys(signer), nil
}
// getSSHAuthMethodForSSHAgent connects to the ssh-agent socket named by the
// SSH_AUTH_SOCK environment variable and returns a public key auth method
// backed by the agent's signers. It fails when the socket cannot be opened,
// the signers cannot be listed or the agent holds no keys.
//
// The agent connection is closed on every failure path. On success it is
// left open.
// NOTE(review): the returned signers presumably need the open connection to
// sign, and nothing in this function's callers closes it later — confirm
// whether it should be released when the tunnel stops.
func (tun *SSHTun) getSSHAuthMethodForSSHAgent() (ssh.AuthMethod, error) {
	conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
	if err != nil {
		return nil, fmt.Errorf("error opening unix socket: %s", err)
	}

	agentClient := agent.NewClient(conn)
	signers, err := agentClient.Signers()
	if err != nil {
		conn.Close() // nothing will use the socket any more
		return nil, fmt.Errorf("error getting ssh-agent signers: %s", err)
	}
	if len(signers) == 0 {
		conn.Close() // nothing will use the socket any more
		return nil, fmt.Errorf("no signers from ssh-agent. Use 'ssh-add' to add keys to agent")
	}
	return ssh.PublicKeys(signers...), nil
}
// forward serves one accepted local connection: it opens an SSH connection
// to the server, dials the remote endpoint through it and copies data in
// both directions until both sides are done, a copy fails or the tunnel is
// stopped. All three connections are closed when it returns.
//
// A failed SSH connection stops the whole tunnel through errStarted; a
// failed remote dial only ends this forward.
func (tun *SSHTun) forward(localConn net.Conn, config *ssh.ClientConfig) {
	defer localConn.Close()

	local := tun.local.connectionString()
	server := tun.server.connectionString()
	remote := tun.remote.connectionString()

	sshConn, err := ssh.Dial(tun.server.connectionType(), server, config)
	if err != nil {
		tun.errStarted(fmt.Errorf("SSH connection to %s failed: %s", server, err.Error()))
		return
	}
	defer sshConn.Close()
	if tun.debug {
		log.Printf("SSH connection to %s done", server)
	}

	remoteConn, err := sshConn.Dial(tun.remote.connectionType(), remote)
	if err != nil {
		if tun.debug {
			log.Printf("Remote dial to %s failed: %s", remote, err.Error())
		}
		return
	}
	defer remoteConn.Close()
	if tun.debug {
		log.Printf("Remote connection to %s done", remote)
	}

	connStr := fmt.Sprintf("%s -(tcp)> %s -(ssh)> %s -(tcp)> %s", localConn.RemoteAddr().String(), local, server, remote)
	if tun.debug {
		log.Printf("SSH tunnel OPEN: %s", connStr)
	}

	myCtx, myCancel := context.WithCancel(tun.ctx)
	defer myCancel()

	// halfCloser is satisfied by connections that can shut down only their
	// write side (for example *net.TCPConn and *net.UnixConn).
	type halfCloser interface {
		CloseWrite() error
	}

	// Buffered for both copiers so neither blocks on send after forward has
	// already returned.
	copyDone := make(chan error, 2)

	// pipe copies src into dst. Each goroutine keeps its own error variable;
	// sharing one between the two directions would be a data race.
	pipe := func(dst, src net.Conn) {
		_, copyErr := io.Copy(dst, src)
		if copyErr == nil {
			// src reached EOF. Pass the end-of-stream on to dst so the peer
			// can finish and close its side; if dst cannot be half-closed,
			// report an error to force a full teardown.
			if hc, ok := dst.(halfCloser); ok {
				copyErr = hc.CloseWrite()
			} else {
				copyErr = io.EOF
			}
		}
		copyDone <- copyErr
	}
	go pipe(remoteConn, localConn)
	go pipe(localConn, remoteConn)

	// Leave as soon as the tunnel is stopped or a direction fails, or once
	// both directions have finished cleanly. The deferred Close calls then
	// unblock whichever copier is still running.
wait:
	for pending := 2; pending > 0; pending-- {
		select {
		case <-myCtx.Done():
			break wait
		case copyErr := <-copyDone:
			if copyErr != nil {
				break wait
			}
		}
	}

	if tun.debug {
		log.Printf("SSH tunnel CLOSE: %s", connStr)
	}
}
// Stop closes the SSH tunnel and its connections by cancelling the tunnel
// context; the blocked Start call then returns nil.
// It does nothing when the tunnel is not started.
// After this call all Set* methods will have effect and Start can be called again.
func (tun *SSHTun) Stop() {
tun.errStarted(nil)
}