@ -23,6 +23,7 @@ import (
"runtime"
"runtime"
"strings"
"strings"
"sync"
"sync"
"sync/atomic"
"syscall"
"syscall"
"time"
"time"
@ -64,81 +65,68 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) {
// commandConn implements net.Conn
// commandConn implements net.Conn
type commandConn struct {
type commandConn struct {
cmdMutex sync . Mutex // for cmd, cmdWaitErr
cmd * exec . Cmd
cmd * exec . Cmd
cmdExited bool
cmdWaitErr error
cmdWaitErr error
cmd Mutex sync . Mutex
cmd Exited atomic . Bool
stdin io . WriteCloser
stdin io . WriteCloser
stdout io . ReadCloser
stdout io . ReadCloser
stderrMu sync . Mutex
stderrMu sync . Mutex // for stderr
stderr bytes . Buffer
stderr bytes . Buffer
stdi oClosedMu sync . Mutex // for stdinClosed and stdoutClosed
stdi nClosed atomic . Bool
std inClosed b ool
std outClosed atomic . B ool
stdoutClosed b ool
closing atomic . B ool
localAddr net . Addr
localAddr net . Addr
remoteAddr net . Addr
remoteAddr net . Addr
}
}
// killIfStdioClosed kills the cmd if both stdin and stdout are closed.
// kill terminates the process. On Windows it kills the process directly,
func ( c * commandConn ) killIfStdioClosed ( ) error {
// whereas on other platforms, a SIGTERM is sent, before forcefully terminating
c . stdioClosedMu . Lock ( )
// the process after 3 seconds.
stdioClosed := c . stdoutClosed && c . stdinClosed
func ( c * commandConn ) kill ( ) {
c . stdioClosedMu . Unlock ( )
if c . cmdExited . Load ( ) {
if ! stdioClosed {
return
return nil
}
}
return c . kill ( )
c . cmdMutex . Lock ( )
}
// killAndWait tries sending SIGTERM to the process before sending SIGKILL.
func killAndWait ( cmd * exec . Cmd ) error {
var werr error
var werr error
if runtime . GOOS != "windows" {
if runtime . GOOS != "windows" {
werrCh := make ( chan error )
werrCh := make ( chan error )
go func ( ) { werrCh <- c md. Wait ( ) } ( )
go func ( ) { werrCh <- c . cmd . Wait ( ) } ( )
cmd. Process . Signal ( syscall . SIGTERM )
_ = c . cmd. Process . Signal ( syscall . SIGTERM )
select {
select {
case werr = <- werrCh :
case werr = <- werrCh :
case <- time . After ( 3 * time . Second ) :
case <- time . After ( 3 * time . Second ) :
cmd. Process . Kill ( )
_ = c . cmd. Process . Kill ( )
werr = <- werrCh
werr = <- werrCh
}
}
} else {
} else {
cmd. Process . Kill ( )
_ = c . cmd. Process . Kill ( )
werr = c md. Wait ( )
werr = c . c md. Wait ( )
}
}
return werr
}
// kill returns nil if the command terminated, regardless to the exit status.
func ( c * commandConn ) kill ( ) error {
var werr error
c . cmdMutex . Lock ( )
if c . cmdExited {
werr = c . cmdWaitErr
} else {
werr = killAndWait ( c . cmd )
c . cmdWaitErr = werr
c . cmdWaitErr = werr
c . cmdExited = true
}
c . cmdMutex . Unlock ( )
c . cmdMutex . Unlock ( )
if werr == nil {
c . cmdExited . Store ( true )
return nil
}
wExitErr , ok := werr . ( * exec . ExitError )
if ok {
if wExitErr . ProcessState . Exited ( ) {
return nil
}
}
return errors . Wrapf ( werr , "commandconn: failed to wait" )
}
}
func ( c * commandConn ) onEOF ( eof error ) error {
// handleEOF handles io.EOF errors while reading or writing from the underlying
// when we got EOF, the command is going to be terminated
// command pipes.
var werr error
//
// When we've received an EOF we expect that the command will
// be terminated soon. As such, we call Wait() on the command
// and return EOF or the error depending on whether the command
// exited with an error.
//
// If Wait() does not return within 10s, an error is returned
func ( c * commandConn ) handleEOF ( err error ) error {
if err != io . EOF {
return err
}
c . cmdMutex . Lock ( )
c . cmdMutex . Lock ( )
if c . cmdExited {
defer c . cmdMutex . Unlock ( )
var werr error
if c . cmdExited . Load ( ) {
werr = c . cmdWaitErr
werr = c . cmdWaitErr
} else {
} else {
werrCh := make ( chan error )
werrCh := make ( chan error )
@ -146,18 +134,17 @@ func (c *commandConn) onEOF(eof error) error {
select {
select {
case werr = <- werrCh :
case werr = <- werrCh :
c . cmdWaitErr = werr
c . cmdWaitErr = werr
c . cmdExited = true
c . cmdExited . Store ( true )
case <- time . After ( 10 * time . Second ) :
case <- time . After ( 10 * time . Second ) :
c . cmdMutex . Unlock ( )
c . stderrMu . Lock ( )
c . stderrMu . Lock ( )
stderr := c . stderr . String ( )
stderr := c . stderr . String ( )
c . stderrMu . Unlock ( )
c . stderrMu . Unlock ( )
return errors . Errorf ( "command %v did not exit after %v: stderr=%q" , c . cmd . Args , e of , stderr )
return errors . Errorf ( "command %v did not exit after %v: stderr=%q" , c . cmd . Args , e rr , stderr )
}
}
}
}
c . cmdMutex . Unlock ( )
if werr == nil {
if werr == nil {
return e of
return e rr
}
}
c . stderrMu . Lock ( )
c . stderrMu . Lock ( )
stderr := c . stderr . String ( )
stderr := c . stderr . String ( )
@ -166,71 +153,86 @@ func (c *commandConn) onEOF(eof error) error {
}
}
func ignorableCloseError ( err error ) bool {
func ignorableCloseError ( err error ) bool {
errS := err . Error ( )
return strings . Contains ( err . Error ( ) , os . ErrClosed . Error ( ) )
ss := [ ] string {
}
os . ErrClosed . Error ( ) ,
}
func ( c * commandConn ) Read ( p [ ] byte ) ( int , error ) {
for _ , s := range ss {
n , err := c . stdout . Read ( p )
if strings . Contains ( errS , s ) {
// check after the call to Read, since
return true
// it is blocking, and while waiting on it
// Close might get called
if c . closing . Load ( ) {
// If we're currently closing the connection
// we don't want to call onEOF
return n , err
}
}
return n , c . handleEOF ( err )
}
func ( c * commandConn ) Write ( p [ ] byte ) ( int , error ) {
n , err := c . stdin . Write ( p )
// check after the call to Write, since
// it is blocking, and while waiting on it
// Close might get called
if c . closing . Load ( ) {
// If we're currently closing the connection
// we don't want to call onEOF
return n , err
}
}
return false
return n , c . handleEOF ( err )
}
}
// CloseRead allows commandConn to implement halfCloser
func ( c * commandConn ) CloseRead ( ) error {
func ( c * commandConn ) CloseRead ( ) error {
// NOTE: maybe already closed here
// NOTE: maybe already closed here
if err := c . stdout . Close ( ) ; err != nil && ! ignorableCloseError ( err ) {
if err := c . stdout . Close ( ) ; err != nil && ! ignorableCloseError ( err ) {
logrus . Warnf ( "commandConn.CloseRead: %v" , err )
return err
}
c . stdioClosedMu . Lock ( )
c . stdoutClosed = true
c . stdioClosedMu . Unlock ( )
if err := c . killIfStdioClosed ( ) ; err != nil {
logrus . Warnf ( "commandConn.CloseRead: %v" , err )
}
}
return nil
c . stdoutClosed . Store ( true )
}
func ( c * commandConn ) Read ( p [ ] byte ) ( int , error ) {
if c . stdinClosed . Load ( ) {
n , err := c . stdout . Read ( p )
c . kill ( )
if err == io . EOF {
err = c . onEOF ( err )
}
}
return n , err
return nil
}
}
// CloseWrite allows commandConn to implement halfCloser
func ( c * commandConn ) CloseWrite ( ) error {
func ( c * commandConn ) CloseWrite ( ) error {
// NOTE: maybe already closed here
// NOTE: maybe already closed here
if err := c . stdin . Close ( ) ; err != nil && ! ignorableCloseError ( err ) {
if err := c . stdin . Close ( ) ; err != nil && ! ignorableCloseError ( err ) {
logrus . Warnf ( "commandConn.CloseWrite: %v" , err )
return err
}
c . stdioClosedMu . Lock ( )
c . stdinClosed = true
c . stdioClosedMu . Unlock ( )
if err := c . killIfStdioClosed ( ) ; err != nil {
logrus . Warnf ( "commandConn.CloseWrite: %v" , err )
}
}
return nil
c . stdinClosed . Store ( true )
}
func ( c * commandConn ) Write ( p [ ] byte ) ( int , error ) {
if c . stdoutClosed . Load ( ) {
n , err := c . stdin . Write ( p )
c . kill ( )
if err == io . EOF {
err = c . onEOF ( err )
}
}
return n , err
return nil
}
}
// Close is the net.Conn func that gets called
// by the transport when a dial is cancelled
// due to it's context timing out. Any blocked
// Read or Write calls will be unblocked and
// return errors. It will block until the underlying
// command has terminated.
func ( c * commandConn ) Close ( ) error {
func ( c * commandConn ) Close ( ) error {
var err error
c . closing . Store ( true )
if err = c . CloseRead ( ) ; err != nil {
defer c . closing . Store ( false )
if err := c . CloseRead ( ) ; err != nil {
logrus . Warnf ( "commandConn.Close: CloseRead: %v" , err )
logrus . Warnf ( "commandConn.Close: CloseRead: %v" , err )
return err
}
}
if err = c . CloseWrite ( ) ; err != nil {
if err : = c . CloseWrite ( ) ; err != nil {
logrus . Warnf ( "commandConn.Close: CloseWrite: %v" , err )
logrus . Warnf ( "commandConn.Close: CloseWrite: %v" , err )
}
return err
return err
}
return nil
}
}
func ( c * commandConn ) LocalAddr ( ) net . Addr {
func ( c * commandConn ) LocalAddr ( ) net . Addr {