|
|
|
@ -23,6 +23,7 @@ import (
|
|
|
|
|
"runtime"
|
|
|
|
|
"strings"
|
|
|
|
|
"sync"
|
|
|
|
|
"sync/atomic"
|
|
|
|
|
"syscall"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
@ -64,81 +65,68 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) {
|
|
|
|
|
|
|
|
|
|
// commandConn implements net.Conn
|
|
|
|
|
type commandConn struct {
|
|
|
|
|
cmd *exec.Cmd
|
|
|
|
|
cmdExited bool
|
|
|
|
|
cmdWaitErr error
|
|
|
|
|
cmdMutex sync.Mutex
|
|
|
|
|
stdin io.WriteCloser
|
|
|
|
|
stdout io.ReadCloser
|
|
|
|
|
stderrMu sync.Mutex
|
|
|
|
|
stderr bytes.Buffer
|
|
|
|
|
stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed
|
|
|
|
|
stdinClosed bool
|
|
|
|
|
stdoutClosed bool
|
|
|
|
|
localAddr net.Addr
|
|
|
|
|
remoteAddr net.Addr
|
|
|
|
|
cmdMutex sync.Mutex // for cmd, cmdWaitErr
|
|
|
|
|
cmd *exec.Cmd
|
|
|
|
|
cmdWaitErr error
|
|
|
|
|
cmdExited atomic.Bool
|
|
|
|
|
stdin io.WriteCloser
|
|
|
|
|
stdout io.ReadCloser
|
|
|
|
|
stderrMu sync.Mutex // for stderr
|
|
|
|
|
stderr bytes.Buffer
|
|
|
|
|
stdinClosed atomic.Bool
|
|
|
|
|
stdoutClosed atomic.Bool
|
|
|
|
|
closing atomic.Bool
|
|
|
|
|
localAddr net.Addr
|
|
|
|
|
remoteAddr net.Addr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// killIfStdioClosed kills the cmd if both stdin and stdout are closed.
|
|
|
|
|
func (c *commandConn) killIfStdioClosed() error {
|
|
|
|
|
c.stdioClosedMu.Lock()
|
|
|
|
|
stdioClosed := c.stdoutClosed && c.stdinClosed
|
|
|
|
|
c.stdioClosedMu.Unlock()
|
|
|
|
|
if !stdioClosed {
|
|
|
|
|
return nil
|
|
|
|
|
// kill terminates the process. On Windows it kills the process directly,
|
|
|
|
|
// whereas on other platforms, a SIGTERM is sent, before forcefully terminating
|
|
|
|
|
// the process after 3 seconds.
|
|
|
|
|
func (c *commandConn) kill() {
|
|
|
|
|
if c.cmdExited.Load() {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
return c.kill()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// killAndWait tries sending SIGTERM to the process before sending SIGKILL.
|
|
|
|
|
func killAndWait(cmd *exec.Cmd) error {
|
|
|
|
|
c.cmdMutex.Lock()
|
|
|
|
|
var werr error
|
|
|
|
|
if runtime.GOOS != "windows" {
|
|
|
|
|
werrCh := make(chan error)
|
|
|
|
|
go func() { werrCh <- cmd.Wait() }()
|
|
|
|
|
cmd.Process.Signal(syscall.SIGTERM)
|
|
|
|
|
go func() { werrCh <- c.cmd.Wait() }()
|
|
|
|
|
_ = c.cmd.Process.Signal(syscall.SIGTERM)
|
|
|
|
|
select {
|
|
|
|
|
case werr = <-werrCh:
|
|
|
|
|
case <-time.After(3 * time.Second):
|
|
|
|
|
cmd.Process.Kill()
|
|
|
|
|
_ = c.cmd.Process.Kill()
|
|
|
|
|
werr = <-werrCh
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
cmd.Process.Kill()
|
|
|
|
|
werr = cmd.Wait()
|
|
|
|
|
_ = c.cmd.Process.Kill()
|
|
|
|
|
werr = c.cmd.Wait()
|
|
|
|
|
}
|
|
|
|
|
return werr
|
|
|
|
|
c.cmdWaitErr = werr
|
|
|
|
|
c.cmdMutex.Unlock()
|
|
|
|
|
c.cmdExited.Store(true)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// kill returns nil if the command terminated, regardless to the exit status.
|
|
|
|
|
func (c *commandConn) kill() error {
|
|
|
|
|
var werr error
|
|
|
|
|
c.cmdMutex.Lock()
|
|
|
|
|
if c.cmdExited {
|
|
|
|
|
werr = c.cmdWaitErr
|
|
|
|
|
} else {
|
|
|
|
|
werr = killAndWait(c.cmd)
|
|
|
|
|
c.cmdWaitErr = werr
|
|
|
|
|
c.cmdExited = true
|
|
|
|
|
}
|
|
|
|
|
c.cmdMutex.Unlock()
|
|
|
|
|
if werr == nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
wExitErr, ok := werr.(*exec.ExitError)
|
|
|
|
|
if ok {
|
|
|
|
|
if wExitErr.ProcessState.Exited() {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
// handleEOF handles io.EOF errors while reading or writing from the underlying
|
|
|
|
|
// command pipes.
|
|
|
|
|
//
|
|
|
|
|
// When we've received an EOF we expect that the command will
|
|
|
|
|
// be terminated soon. As such, we call Wait() on the command
|
|
|
|
|
// and return EOF or the error depending on whether the command
|
|
|
|
|
// exited with an error.
|
|
|
|
|
//
|
|
|
|
|
// If Wait() does not return within 10s, an error is returned
|
|
|
|
|
func (c *commandConn) handleEOF(err error) error {
|
|
|
|
|
if err != io.EOF {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return errors.Wrapf(werr, "commandconn: failed to wait")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *commandConn) onEOF(eof error) error {
|
|
|
|
|
// when we got EOF, the command is going to be terminated
|
|
|
|
|
var werr error
|
|
|
|
|
c.cmdMutex.Lock()
|
|
|
|
|
if c.cmdExited {
|
|
|
|
|
defer c.cmdMutex.Unlock()
|
|
|
|
|
|
|
|
|
|
var werr error
|
|
|
|
|
if c.cmdExited.Load() {
|
|
|
|
|
werr = c.cmdWaitErr
|
|
|
|
|
} else {
|
|
|
|
|
werrCh := make(chan error)
|
|
|
|
@ -146,18 +134,17 @@ func (c *commandConn) onEOF(eof error) error {
|
|
|
|
|
select {
|
|
|
|
|
case werr = <-werrCh:
|
|
|
|
|
c.cmdWaitErr = werr
|
|
|
|
|
c.cmdExited = true
|
|
|
|
|
c.cmdExited.Store(true)
|
|
|
|
|
case <-time.After(10 * time.Second):
|
|
|
|
|
c.cmdMutex.Unlock()
|
|
|
|
|
c.stderrMu.Lock()
|
|
|
|
|
stderr := c.stderr.String()
|
|
|
|
|
c.stderrMu.Unlock()
|
|
|
|
|
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr)
|
|
|
|
|
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
c.cmdMutex.Unlock()
|
|
|
|
|
|
|
|
|
|
if werr == nil {
|
|
|
|
|
return eof
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
c.stderrMu.Lock()
|
|
|
|
|
stderr := c.stderr.String()
|
|
|
|
@ -166,71 +153,86 @@ func (c *commandConn) onEOF(eof error) error {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func ignorableCloseError(err error) bool {
|
|
|
|
|
errS := err.Error()
|
|
|
|
|
ss := []string{
|
|
|
|
|
os.ErrClosed.Error(),
|
|
|
|
|
return strings.Contains(err.Error(), os.ErrClosed.Error())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *commandConn) Read(p []byte) (int, error) {
|
|
|
|
|
n, err := c.stdout.Read(p)
|
|
|
|
|
// check after the call to Read, since
|
|
|
|
|
// it is blocking, and while waiting on it
|
|
|
|
|
// Close might get called
|
|
|
|
|
if c.closing.Load() {
|
|
|
|
|
// If we're currently closing the connection
|
|
|
|
|
// we don't want to call onEOF
|
|
|
|
|
return n, err
|
|
|
|
|
}
|
|
|
|
|
for _, s := range ss {
|
|
|
|
|
if strings.Contains(errS, s) {
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return n, c.handleEOF(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *commandConn) Write(p []byte) (int, error) {
|
|
|
|
|
n, err := c.stdin.Write(p)
|
|
|
|
|
// check after the call to Write, since
|
|
|
|
|
// it is blocking, and while waiting on it
|
|
|
|
|
// Close might get called
|
|
|
|
|
if c.closing.Load() {
|
|
|
|
|
// If we're currently closing the connection
|
|
|
|
|
// we don't want to call onEOF
|
|
|
|
|
return n, err
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
|
|
|
|
|
return n, c.handleEOF(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// CloseRead allows commandConn to implement halfCloser
|
|
|
|
|
func (c *commandConn) CloseRead() error {
|
|
|
|
|
// NOTE: maybe already closed here
|
|
|
|
|
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
|
|
|
|
|
logrus.Warnf("commandConn.CloseRead: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
c.stdioClosedMu.Lock()
|
|
|
|
|
c.stdoutClosed = true
|
|
|
|
|
c.stdioClosedMu.Unlock()
|
|
|
|
|
if err := c.killIfStdioClosed(); err != nil {
|
|
|
|
|
logrus.Warnf("commandConn.CloseRead: %v", err)
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
c.stdoutClosed.Store(true)
|
|
|
|
|
|
|
|
|
|
func (c *commandConn) Read(p []byte) (int, error) {
|
|
|
|
|
n, err := c.stdout.Read(p)
|
|
|
|
|
if err == io.EOF {
|
|
|
|
|
err = c.onEOF(err)
|
|
|
|
|
if c.stdinClosed.Load() {
|
|
|
|
|
c.kill()
|
|
|
|
|
}
|
|
|
|
|
return n, err
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// CloseWrite allows commandConn to implement halfCloser
|
|
|
|
|
func (c *commandConn) CloseWrite() error {
|
|
|
|
|
// NOTE: maybe already closed here
|
|
|
|
|
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
|
|
|
|
|
logrus.Warnf("commandConn.CloseWrite: %v", err)
|
|
|
|
|
}
|
|
|
|
|
c.stdioClosedMu.Lock()
|
|
|
|
|
c.stdinClosed = true
|
|
|
|
|
c.stdioClosedMu.Unlock()
|
|
|
|
|
if err := c.killIfStdioClosed(); err != nil {
|
|
|
|
|
logrus.Warnf("commandConn.CloseWrite: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
c.stdinClosed.Store(true)
|
|
|
|
|
|
|
|
|
|
func (c *commandConn) Write(p []byte) (int, error) {
|
|
|
|
|
n, err := c.stdin.Write(p)
|
|
|
|
|
if err == io.EOF {
|
|
|
|
|
err = c.onEOF(err)
|
|
|
|
|
if c.stdoutClosed.Load() {
|
|
|
|
|
c.kill()
|
|
|
|
|
}
|
|
|
|
|
return n, err
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Close is the net.Conn func that gets called
|
|
|
|
|
// by the transport when a dial is cancelled
|
|
|
|
|
// due to it's context timing out. Any blocked
|
|
|
|
|
// Read or Write calls will be unblocked and
|
|
|
|
|
// return errors. It will block until the underlying
|
|
|
|
|
// command has terminated.
|
|
|
|
|
func (c *commandConn) Close() error {
|
|
|
|
|
var err error
|
|
|
|
|
if err = c.CloseRead(); err != nil {
|
|
|
|
|
c.closing.Store(true)
|
|
|
|
|
defer c.closing.Store(false)
|
|
|
|
|
|
|
|
|
|
if err := c.CloseRead(); err != nil {
|
|
|
|
|
logrus.Warnf("commandConn.Close: CloseRead: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if err = c.CloseWrite(); err != nil {
|
|
|
|
|
if err := c.CloseWrite(); err != nil {
|
|
|
|
|
logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *commandConn) LocalAddr() net.Addr {
|
|
|
|
|