You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
432 lines
9.1 KiB
Go
432 lines
9.1 KiB
Go
package remote
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/docker/buildx/controller/pb"
|
|
"github.com/moby/sys/signal"
|
|
"github.com/sirupsen/logrus"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
type msgStream interface {
|
|
Send(*pb.Message) error
|
|
Recv() (*pb.Message, error)
|
|
}
|
|
|
|
type ioServerConfig struct {
|
|
stdin io.WriteCloser
|
|
stdout, stderr io.ReadCloser
|
|
|
|
// signalFn is a callback function called when a signal is reached to the client.
|
|
signalFn func(context.Context, syscall.Signal) error
|
|
|
|
// resizeFn is a callback function called when a resize event is reached to the client.
|
|
resizeFn func(context.Context, winSize) error
|
|
}
|
|
|
|
func serveIO(attachCtx context.Context, srv msgStream, initFn func(*pb.InitMessage) error, ioConfig *ioServerConfig) (err error) {
|
|
stdin, stdout, stderr := ioConfig.stdin, ioConfig.stdout, ioConfig.stderr
|
|
stream := &debugStream{srv, "server=" + time.Now().String()}
|
|
eg, ctx := errgroup.WithContext(attachCtx)
|
|
done := make(chan struct{})
|
|
|
|
msg, err := receive(ctx, stream)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
init := msg.GetInit()
|
|
if init == nil {
|
|
return fmt.Errorf("unexpected message: %T; wanted init", msg.GetInput())
|
|
}
|
|
ref := init.Ref
|
|
if ref == "" {
|
|
return fmt.Errorf("no ref is provided")
|
|
}
|
|
if err := initFn(init); err != nil {
|
|
return fmt.Errorf("failed to initialize IO server: %w", err)
|
|
}
|
|
|
|
if stdout != nil {
|
|
stdoutReader, stdoutWriter := io.Pipe()
|
|
eg.Go(func() error {
|
|
<-done
|
|
return stdoutWriter.Close()
|
|
})
|
|
|
|
go func() {
|
|
// do not wait for read completion but return here and let the caller send EOF
|
|
// this allows us to return on ctx.Done() without being blocked by this reader.
|
|
io.Copy(stdoutWriter, stdout)
|
|
stdoutWriter.Close()
|
|
}()
|
|
|
|
eg.Go(func() error {
|
|
defer stdoutReader.Close()
|
|
return copyToStream(1, stream, stdoutReader)
|
|
})
|
|
}
|
|
|
|
if stderr != nil {
|
|
stderrReader, stderrWriter := io.Pipe()
|
|
eg.Go(func() error {
|
|
<-done
|
|
return stderrWriter.Close()
|
|
})
|
|
|
|
go func() {
|
|
// do not wait for read completion but return here and let the caller send EOF
|
|
// this allows us to return on ctx.Done() without being blocked by this reader.
|
|
io.Copy(stderrWriter, stderr)
|
|
stderrWriter.Close()
|
|
}()
|
|
|
|
eg.Go(func() error {
|
|
defer stderrReader.Close()
|
|
return copyToStream(2, stream, stderrReader)
|
|
})
|
|
}
|
|
|
|
msgCh := make(chan *pb.Message)
|
|
eg.Go(func() error {
|
|
defer close(msgCh)
|
|
for {
|
|
msg, err := receive(ctx, stream)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case msgCh <- msg:
|
|
case <-done:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
})
|
|
|
|
eg.Go(func() error {
|
|
defer close(done)
|
|
for {
|
|
var msg *pb.Message
|
|
select {
|
|
case msg = <-msgCh:
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
if msg == nil {
|
|
return nil
|
|
}
|
|
if file := msg.GetFile(); file != nil {
|
|
if file.Fd != 0 {
|
|
return fmt.Errorf("unexpected fd: %v", file.Fd)
|
|
}
|
|
if stdin == nil {
|
|
continue // no stdin destination is specified so ignore the data
|
|
}
|
|
if len(file.Data) > 0 {
|
|
_, err := stdin.Write(file.Data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if file.EOF {
|
|
stdin.Close()
|
|
}
|
|
} else if resize := msg.GetResize(); resize != nil {
|
|
if ioConfig.resizeFn != nil {
|
|
ioConfig.resizeFn(ctx, winSize{
|
|
cols: resize.Cols,
|
|
rows: resize.Rows,
|
|
})
|
|
}
|
|
} else if sig := msg.GetSignal(); sig != nil {
|
|
if ioConfig.signalFn != nil {
|
|
syscallSignal, ok := signal.SignalMap[sig.Name]
|
|
if !ok {
|
|
continue
|
|
}
|
|
ioConfig.signalFn(ctx, syscallSignal)
|
|
}
|
|
} else {
|
|
return fmt.Errorf("unexpected message: %T", msg.GetInput())
|
|
}
|
|
}
|
|
})
|
|
|
|
return eg.Wait()
|
|
}
|
|
|
|
type ioAttachConfig struct {
|
|
stdin io.ReadCloser
|
|
stdout, stderr io.WriteCloser
|
|
signal <-chan syscall.Signal
|
|
resize <-chan winSize
|
|
}
|
|
|
|
// winSize carries the row and column counts exchanged in resize messages.
type winSize struct {
	rows, cols uint32
}
|
|
|
|
func attachIO(ctx context.Context, stream msgStream, initMessage *pb.InitMessage, cfg ioAttachConfig) (retErr error) {
|
|
eg, ctx := errgroup.WithContext(ctx)
|
|
done := make(chan struct{})
|
|
|
|
if err := stream.Send(&pb.Message{
|
|
Input: &pb.Message_Init{
|
|
Init: initMessage,
|
|
},
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to init: %w", err)
|
|
}
|
|
|
|
if cfg.stdin != nil {
|
|
stdinReader, stdinWriter := io.Pipe()
|
|
eg.Go(func() error {
|
|
<-done
|
|
return stdinWriter.Close()
|
|
})
|
|
|
|
go func() {
|
|
// do not wait for read completion but return here and let the caller send EOF
|
|
// this allows us to return on ctx.Done() without being blocked by this reader.
|
|
io.Copy(stdinWriter, cfg.stdin)
|
|
stdinWriter.Close()
|
|
}()
|
|
|
|
eg.Go(func() error {
|
|
defer stdinReader.Close()
|
|
return copyToStream(0, stream, stdinReader)
|
|
})
|
|
}
|
|
|
|
if cfg.signal != nil {
|
|
eg.Go(func() error {
|
|
for {
|
|
var sig syscall.Signal
|
|
select {
|
|
case sig = <-cfg.signal:
|
|
case <-done:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
name := sigToName[sig]
|
|
if name == "" {
|
|
continue
|
|
}
|
|
if err := stream.Send(&pb.Message{
|
|
Input: &pb.Message_Signal{
|
|
Signal: &pb.SignalMessage{
|
|
Name: name,
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to send signal: %w", err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
if cfg.resize != nil {
|
|
eg.Go(func() error {
|
|
for {
|
|
var win winSize
|
|
select {
|
|
case win = <-cfg.resize:
|
|
case <-done:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
if err := stream.Send(&pb.Message{
|
|
Input: &pb.Message_Resize{
|
|
Resize: &pb.ResizeMessage{
|
|
Rows: win.rows,
|
|
Cols: win.cols,
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to send resize: %w", err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
msgCh := make(chan *pb.Message)
|
|
eg.Go(func() error {
|
|
defer close(msgCh)
|
|
for {
|
|
msg, err := receive(ctx, stream)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case msgCh <- msg:
|
|
case <-done:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
})
|
|
|
|
eg.Go(func() error {
|
|
eofs := make(map[uint32]struct{})
|
|
defer close(done)
|
|
for {
|
|
var msg *pb.Message
|
|
select {
|
|
case msg = <-msgCh:
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
if msg == nil {
|
|
return nil
|
|
}
|
|
if file := msg.GetFile(); file != nil {
|
|
if _, ok := eofs[file.Fd]; ok {
|
|
continue
|
|
}
|
|
var out io.WriteCloser
|
|
switch file.Fd {
|
|
case 1:
|
|
out = cfg.stdout
|
|
case 2:
|
|
out = cfg.stderr
|
|
default:
|
|
return fmt.Errorf("unsupported fd %d", file.Fd)
|
|
|
|
}
|
|
if out == nil {
|
|
logrus.Warnf("attachIO: no writer for fd %d", file.Fd)
|
|
continue
|
|
}
|
|
if len(file.Data) > 0 {
|
|
if _, err := out.Write(file.Data); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if file.EOF {
|
|
eofs[file.Fd] = struct{}{}
|
|
}
|
|
} else {
|
|
return fmt.Errorf("unexpected message: %T", msg.GetInput())
|
|
}
|
|
}
|
|
})
|
|
|
|
return eg.Wait()
|
|
}
|
|
|
|
func receive(ctx context.Context, stream msgStream) (*pb.Message, error) {
|
|
msgCh := make(chan *pb.Message)
|
|
errCh := make(chan error)
|
|
go func() {
|
|
msg, err := stream.Recv()
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
return
|
|
}
|
|
errCh <- err
|
|
return
|
|
}
|
|
msgCh <- msg
|
|
}()
|
|
select {
|
|
case msg := <-msgCh:
|
|
return msg, nil
|
|
case err := <-errCh:
|
|
return nil, err
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
func copyToStream(fd uint32, snd msgStream, r io.Reader) error {
|
|
for {
|
|
buf := make([]byte, 32*1024)
|
|
n, err := r.Read(buf)
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
break // break loop and send EOF
|
|
}
|
|
return err
|
|
} else if n > 0 {
|
|
if snd.Send(&pb.Message{
|
|
Input: &pb.Message_File{
|
|
File: &pb.FdMessage{
|
|
Fd: fd,
|
|
Data: buf[:n],
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return snd.Send(&pb.Message{
|
|
Input: &pb.Message_File{
|
|
File: &pb.FdMessage{
|
|
Fd: fd,
|
|
EOF: true,
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
// sigToName maps a signal number to the name used for it in signal messages.
// It starts out empty and is filled from signal.SignalMap in init.
var sigToName = make(map[syscall.Signal]string)
|
|
|
|
func init() {
|
|
for name, value := range signal.SignalMap {
|
|
sigToName[value] = name
|
|
}
|
|
}
|
|
|
|
type debugStream struct {
|
|
msgStream
|
|
prefix string
|
|
}
|
|
|
|
func (s *debugStream) Send(msg *pb.Message) error {
|
|
switch m := msg.GetInput().(type) {
|
|
case *pb.Message_File:
|
|
if m.File.EOF {
|
|
logrus.Debugf("|---> File Message (sender:%v) fd=%d, EOF", s.prefix, m.File.Fd)
|
|
} else {
|
|
logrus.Debugf("|---> File Message (sender:%v) fd=%d, %d bytes", s.prefix, m.File.Fd, len(m.File.Data))
|
|
}
|
|
case *pb.Message_Resize:
|
|
logrus.Debugf("|---> Resize Message (sender:%v): %+v", s.prefix, m.Resize)
|
|
case *pb.Message_Signal:
|
|
logrus.Debugf("|---> Signal Message (sender:%v): %s", s.prefix, m.Signal.Name)
|
|
}
|
|
return s.msgStream.Send(msg)
|
|
}
|
|
|
|
func (s *debugStream) Recv() (*pb.Message, error) {
|
|
msg, err := s.msgStream.Recv()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
switch m := msg.GetInput().(type) {
|
|
case *pb.Message_File:
|
|
if m.File.EOF {
|
|
logrus.Debugf("|<--- File Message (receiver:%v) fd=%d, EOF", s.prefix, m.File.Fd)
|
|
} else {
|
|
logrus.Debugf("|<--- File Message (receiver:%v) fd=%d, %d bytes", s.prefix, m.File.Fd, len(m.File.Data))
|
|
}
|
|
case *pb.Message_Resize:
|
|
logrus.Debugf("|<--- Resize Message (receiver:%v): %+v", s.prefix, m.Resize)
|
|
case *pb.Message_Signal:
|
|
logrus.Debugf("|<--- Signal Message (receiver:%v): %s", s.prefix, m.Signal.Name)
|
|
}
|
|
return msg, nil
|
|
}
|