diff --git a/build/build.go b/build/build.go index 2020d939..c185702d 100644 --- a/build/build.go +++ b/build/build.go @@ -615,7 +615,101 @@ func toSolveOpt(ctx context.Context, di DriverInfo, multiDriver bool, opt Option return &so, releaseF, nil } +// ContainerConfig is configuration for a container to run. +type ContainerConfig struct { + ResultCtx *ResultContext + Args []string + Env []string + User string + Cwd string + Tty bool + Stdin io.ReadCloser + Stdout io.WriteCloser + Stderr io.WriteCloser +} + +// ResultContext is a build result with the client that built it. +type ResultContext struct { + Client *client.Client + Res *gateway.Result +} + +// Invoke invokes a build result as a container. +func Invoke(ctx context.Context, cfg ContainerConfig) error { + if cfg.ResultCtx == nil { + return errors.Errorf("result must be provided") + } + c, res := cfg.ResultCtx.Client, cfg.ResultCtx.Res + _, err := c.Build(ctx, client.SolveOpt{}, "buildx", func(ctx context.Context, c gateway.Client) (*gateway.Result, error) { + if res.Ref == nil { + return nil, errors.Errorf("no reference is registered") + } + st, err := res.Ref.ToState() + if err != nil { + return nil, err + } + def, err := st.Marshal(ctx) + if err != nil { + return nil, err + } + imgRef, err := c.Solve(ctx, gateway.SolveRequest{ + Definition: def.ToPB(), + }) + if err != nil { + return nil, err + } + ctr, err := c.NewContainer(ctx, gateway.NewContainerRequest{ + Mounts: []gateway.Mount{ + { + Dest: "/", + MountType: pb.MountType_BIND, + Ref: imgRef.Ref, + }, + }, + }) + if err != nil { + return nil, err + } + defer ctr.Release(ctx) + proc, err := ctr.Start(ctx, gateway.StartRequest{ + Args: cfg.Args, + Env: cfg.Env, + User: cfg.User, + Cwd: cfg.Cwd, + Tty: cfg.Tty, + Stdin: cfg.Stdin, + Stdout: cfg.Stdout, + Stderr: cfg.Stderr, + }) + if err != nil { + return nil, errors.Errorf("failed to start container: %v", err) + } + errCh := make(chan error) + doneCh := make(chan struct{}) + go func() { + if err := proc.Wait(); err != nil { + errCh <- err + return + } + close(doneCh) + }() + select { + case <-doneCh: + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-errCh: + return nil, err + } + return nil, nil + }, nil) + return err +} + func Build(ctx context.Context, drivers []DriverInfo, opt map[string]Options, docker DockerAPI, configDir string, w progress.Writer) (resp map[string]*client.SolveResponse, err error) { + return BuildWithResultHandler(ctx, drivers, opt, docker, configDir, w, nil) +} + +func BuildWithResultHandler(ctx context.Context, drivers []DriverInfo, opt map[string]Options, docker DockerAPI, configDir string, w progress.Writer, resultHandleFunc func(driverIndex int, rCtx *ResultContext)) (resp map[string]*client.SolveResponse, err error) { if len(drivers) == 0 { return nil, errors.Errorf("driver required for build") } @@ -927,12 +1021,16 @@ func Build(ctx context.Context, drivers []DriverInfo, opt map[string]Options, do ch, done := progress.NewChannel(pw) defer func() { <-done }() + cc := c rr, err := c.Build(ctx, so, "buildx", func(ctx context.Context, c gateway.Client) (*gateway.Result, error) { res, err := c.Solve(ctx, req) if err != nil { return nil, err } results.Set(resultKey(dp.driverIndex, k), res) + if resultHandleFunc != nil { + resultHandleFunc(dp.driverIndex, &ResultContext{cc, res}) + } return res, nil }, ch) if err != nil { diff --git a/commands/build.go b/commands/build.go index 70a745b4..baf63f31 100644 --- a/commands/build.go +++ b/commands/build.go @@ -4,14 +4,19 @@ import ( "bytes" "context" "encoding/base64" + "encoding/csv" "encoding/json" "fmt" "io" "os" "path/filepath" + "strconv" "strings" + "sync" + "github.com/containerd/console" "github.com/docker/buildx/build" + "github.com/docker/buildx/monitor" "github.com/docker/buildx/util/buildflags" "github.com/docker/buildx/util/confutil" "github.com/docker/buildx/util/platformutil" @@ -63,6 +68,7 @@ type buildOptions struct { tags []string target string ulimits *dockeropts.UlimitOpt + invoke string commonOptions } @@ -225,22 +231,48 @@ func runBuild(dockerCli command.Cli, in buildOptions) (err error) { contextPathHash = in.contextPath } - imageID, err := buildTargets(ctx, dockerCli, map[string]build.Options{defaultTargetName: opts}, in.progress, contextPathHash, in.builder, in.metadataFile) + imageID, res, err := buildTargets(ctx, dockerCli, map[string]build.Options{defaultTargetName: opts}, in.progress, contextPathHash, in.builder, in.metadataFile) err = wrapBuildError(err, false) if err != nil { return err } + if in.invoke != "" { + cfg, err := parseInvokeConfig(in.invoke) + if err != nil { + return err + } + cfg.ResultCtx = res + con := console.Current() + if err := con.SetRaw(); err != nil { + return errors.Errorf("failed to configure terminal: %v", err) + } + err = monitor.RunMonitor(ctx, cfg, func(ctx context.Context) (*build.ResultContext, error) { + _, rr, err := buildTargets(ctx, dockerCli, map[string]build.Options{defaultTargetName: opts}, in.progress, contextPathHash, in.builder, in.metadataFile) + return rr, err + }, io.NopCloser(os.Stdin), nopCloser{os.Stdout}, nopCloser{os.Stderr}) + if err != nil { + logrus.Warnf("failed to run monitor: %v", err) + } + con.Reset() + } + if in.quiet { fmt.Println(imageID) } return nil } -func buildTargets(ctx context.Context, dockerCli command.Cli, opts map[string]build.Options, progressMode, contextPathHash, instance string, metadataFile string) (imageID string, err error) { +type nopCloser struct { + io.WriteCloser +} + +func (c nopCloser) Close() error { return nil } + +func buildTargets(ctx context.Context, dockerCli command.Cli, opts map[string]build.Options, progressMode, contextPathHash, instance string, metadataFile string) (imageID string, res *build.ResultContext, err error) { dis, err := getInstanceOrDefault(ctx, dockerCli, instance, contextPathHash) if err != nil { - return "", err + return "", nil, err } ctx2, cancel := context.WithCancel(context.TODO()) @@ -248,24 +280,82 @@ func buildTargets(ctx context.Context, dockerCli command.Cli, opts map[string]bu printer := progress.NewPrinter(ctx2, os.Stderr, os.Stderr, progressMode) - resp, err := build.Build(ctx, dis, opts, dockerAPI(dockerCli), confutil.ConfigDir(dockerCli), printer) + var mu sync.Mutex + var idx int + resp, err := build.BuildWithResultHandler(ctx, dis, opts, dockerAPI(dockerCli), confutil.ConfigDir(dockerCli), printer, func(driverIndex int, gotRes *build.ResultContext) { + mu.Lock() + defer mu.Unlock() + if res == nil || driverIndex < idx { + idx, res = driverIndex, gotRes + } + }) err1 := printer.Wait() if err == nil { err = err1 } if err != nil { - return "", err + return "", nil, err } if len(metadataFile) > 0 && resp != nil { if err := writeMetadataFile(metadataFile, decodeExporterResponse(resp[defaultTargetName].ExporterResponse)); err != nil { - return "", err + return "", nil, err } } printWarnings(os.Stderr, printer.Warnings(), progressMode) - return resp[defaultTargetName].ExporterResponse["containerimage.digest"], err + return resp[defaultTargetName].ExporterResponse["containerimage.digest"], res, err +} + +func parseInvokeConfig(invoke string) (cfg build.ContainerConfig, err error) { + csvReader := csv.NewReader(strings.NewReader(invoke)) + fields, err := csvReader.Read() + if err != nil { + return cfg, err + } + cfg.Tty = true + if len(fields) == 1 && !strings.Contains(fields[0], "=") { + cfg.Args = []string{fields[0]} + return cfg, nil + } + var entrypoint string + var args []string + for _, field := range fields { + parts := strings.SplitN(field, "=", 2) + if len(parts) != 2 { + return cfg, errors.Errorf("invalid value %s", field) + } + key := strings.ToLower(parts[0]) + value := parts[1] + switch key { + case "args": + args = append(args, value) // TODO: support JSON + case "entrypoint": + entrypoint = value // TODO: support JSON + case "env": + cfg.Env = append(cfg.Env, value) + case "user": + cfg.User = value + case "cwd": + cfg.Cwd = value + case "tty": + cfg.Tty, err = strconv.ParseBool(value) + if err != nil { + return cfg, errors.Errorf("failed to parse tty: %v", err) + } + default: + return cfg, errors.Errorf("unknown key %q", key) + } + } + cfg.Args = args + if entrypoint != "" { + cfg.Args = append([]string{entrypoint}, cfg.Args...) + } + if len(cfg.Args) == 0 { + cfg.Args = []string{"sh"} + } + return cfg, nil } func printWarnings(w io.Writer, warnings []client.VertexWarning, mode string) { @@ -389,6 +479,10 @@ func buildCmd(dockerCli command.Cli, rootOpts *rootOptions) *cobra.Command { flags.Var(options.ulimits, "ulimit", "Ulimit options") + if os.Getenv("BUILDX_EXPERIMENTAL") == "1" { + flags.StringVar(&options.invoke, "invoke", "", "Invoke a command after the build. BUILDX_EXPERIMENTAL=1 is required.") + } + // hidden flags var ignore string var ignoreSlice []string diff --git a/go.mod b/go.mod index 95f1ce52..d53c9e08 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( go.opentelemetry.io/otel v1.4.1 go.opentelemetry.io/otel/trace v1.4.1 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c + golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 google.golang.org/grpc v1.45.0 k8s.io/api v0.23.4 k8s.io/apimachinery v0.23.4 @@ -128,7 +129,6 @@ require ( golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f // indirect golang.org/x/sys v0.0.0-20220405210540-1e041c57c461 // indirect - golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/monitor/monitor.go b/monitor/monitor.go new file mode 100644 index 00000000..4abaa97e --- /dev/null +++ b/monitor/monitor.go @@ -0,0 +1,486 @@ +package monitor + +import ( + "bufio" + "context" + "fmt" + "io" + "sync" + + "github.com/docker/buildx/build" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "golang.org/x/term" +) + +// RunMonitor provides an interactive session for running and managing containers via specified IO. +func RunMonitor(ctx context.Context, containerConfig build.ContainerConfig, reloadFunc func(context.Context) (*build.ResultContext, error), stdin io.ReadCloser, stdout, stderr io.WriteCloser) error { + monitorIn, monitorOut := ioSetPipe() + defer monitorIn.Close() + monitorEnableCh := make(chan struct{}) + monitorDisableCh := make(chan struct{}) + monitorOutCtx := ioSetOutContext{monitorOut, + func() { monitorEnableCh <- struct{}{} }, + func() { monitorDisableCh <- struct{}{} }, + } + + containerIn, containerOut := ioSetPipe() + defer containerIn.Close() + containerOutCtx := ioSetOutContext{containerOut, + // send newline to hopefully get the prompt; TODO: better UI (e.g. reprinting the last line) + func() { containerOut.stdin.Write([]byte("\n")) }, + func() {}, + } + + m := &monitor{ + invokeIO: newIOForwarder(containerIn), + muxIO: newMuxIO(ioSetIn{stdin, stdout, stderr}, []ioSetOutContext{monitorOutCtx, containerOutCtx}, 1, "Switched IO\n"), + } + + // Start container automatically + go func() { + m.rollback(ctx, containerConfig) + }() + + // Serve monitor commands + monitorForwarder := newIOForwarder(monitorIn) + for { + <-monitorEnableCh + in, out := ioSetPipe() + monitorForwarder.setDestination(&out) + doneCh, errCh := make(chan struct{}), make(chan error) + go func() { + defer close(doneCh) + defer in.Close() + t := term.NewTerminal(readWriter{in.stdin, in.stdout}, "(buildx) ") + for { + l, err := t.ReadLine() + if err != nil { + if err != io.EOF { + errCh <- err + return + } + return + } + switch l { + case "": + // nop + case "reload": + res, err := reloadFunc(ctx) + if err != nil { + fmt.Printf("failed to reload: %v\n", err) + } else { + // rollback the running container with the new result + containerConfig.ResultCtx = res + m.rollback(ctx, containerConfig) + } + case "rollback": + m.rollback(ctx, containerConfig) + case "exit": + return + default: + fmt.Printf("unknown command: %q\n", l) + } + } + }() + select { + case <-doneCh: + return nil + case err := <-errCh: + return err + case <-monitorDisableCh: + } + monitorForwarder.setDestination(nil) + } +} + +type readWriter struct { + io.Reader + io.Writer +} + +type monitor struct { + muxIO *muxIO + invokeIO *ioForwarder + curInvokeCancel func() +} + +func (m *monitor) rollback(ctx context.Context, cfg build.ContainerConfig) { + if m.curInvokeCancel != nil { + m.curInvokeCancel() // Finish the running container if exists + } + go func() { + // Start a new container + if err := m.invoke(ctx, cfg); err != nil { + logrus.Debugf("invoke error: %v", err) + } + }() +} + +func (m *monitor) invoke(ctx context.Context, cfg build.ContainerConfig) error { + m.muxIO.enable(1) + defer m.muxIO.disable(1) + invokeCtx, invokeCancel := context.WithCancel(ctx) + + containerIn, containerOut := ioSetPipe() + m.invokeIO.setDestination(&containerOut) + waitInvokeDoneCh := make(chan struct{}) + var cancelOnce sync.Once + curInvokeCancel := func() { + cancelOnce.Do(func() { + containerIn.Close() + m.invokeIO.setDestination(nil) + invokeCancel() + }) + <-waitInvokeDoneCh + } + defer curInvokeCancel() + m.curInvokeCancel = curInvokeCancel + + cfg.Stdin = containerIn.stdin + cfg.Stdout = containerIn.stdout + cfg.Stderr = containerIn.stderr + err := build.Invoke(invokeCtx, cfg) + close(waitInvokeDoneCh) + + return err +} + +type ioForwarder struct { + curIO *ioSetOut + mu sync.Mutex + updateCh chan struct{} +} + +func newIOForwarder(in ioSetIn) *ioForwarder { + f := &ioForwarder{ + updateCh: make(chan struct{}), + } + doneCh := make(chan struct{}) + go func() { + for { + f.mu.Lock() + w := f.curIO + f.mu.Unlock() + if w != nil && w.stdout != nil && w.stderr != nil { + go func() { + if _, err := io.Copy(in.stdout, w.stdout); err != nil && err != io.ErrClosedPipe { + // ErrClosedPipe is OK as we close this read end during setDestination. + logrus.WithError(err).Warnf("failed to forward stdout: %v", err) + } + }() + go func() { + if _, err := io.Copy(in.stderr, w.stderr); err != nil && err != io.ErrClosedPipe { + // ErrClosedPipe is OK as we close this read end during setDestination. + logrus.WithError(err).Warnf("failed to forward stderr: %v", err) + } + }() + } + select { + case <-f.updateCh: + case <-doneCh: + return + } + } + }() + go func() { + if err := copyToFunc(in.stdin, func() (io.Writer, error) { + f.mu.Lock() + w := f.curIO + f.mu.Unlock() + if w != nil { + return w.stdin, nil + } + return nil, nil + }); err != nil && err != io.ErrClosedPipe { + logrus.WithError(err).Warnf("failed to forward IO: %v", err) + } + close(doneCh) + + if w := f.curIO; w != nil { + // Propagate close + if err := w.Close(); err != nil { + logrus.WithError(err).Warnf("failed to forwarded stdin IO: %v", err) + } + } + }() + return f +} + +func (f *ioForwarder) setDestination(out *ioSetOut) { + f.mu.Lock() + if f.curIO != nil { + // close all stream on the current IO no to mix with the new IO + f.curIO.Close() + } + f.curIO = out + f.mu.Unlock() + f.updateCh <- struct{}{} +} + +type ioSetOutContext struct { + ioSetOut + enableHook func() + disableHook func() +} + +// newMuxIO forwards IO stream to/from "in" and "outs". +// "outs" are closed automatically when "in" reaches EOF. +// "in" doesn't closed automatically so the caller needs to explicitly close it. +func newMuxIO(in ioSetIn, out []ioSetOutContext, initIdx int, toggleMessage string) *muxIO { + m := &muxIO{ + enabled: make(map[int]struct{}), + in: in, + out: out, + closedCh: make(chan struct{}), + toggleMessage: toggleMessage, + } + for i := range out { + m.enabled[i] = struct{}{} + } + m.maxCur = len(out) + m.cur = initIdx + var wg sync.WaitGroup + var mu sync.Mutex + for i, o := range out { + i, o := i, o + wg.Add(1) + go func() { + defer wg.Done() + if err := copyToFunc(o.stdout, func() (io.Writer, error) { + if m.cur == i { + return in.stdout, nil + } + return nil, nil + }); err != nil { + logrus.WithField("output index", i).WithError(err).Warnf("failed to write stdout") + } + if err := o.stdout.Close(); err != nil { + logrus.WithField("output index", i).WithError(err).Warnf("failed to close stdout") + } + }() + wg.Add(1) + go func() { + defer wg.Done() + if err := copyToFunc(o.stderr, func() (io.Writer, error) { + if m.cur == i { + return in.stderr, nil + } + return nil, nil + }); err != nil { + logrus.WithField("output index", i).WithError(err).Warnf("failed to write stderr") + } + if err := o.stderr.Close(); err != nil { + logrus.WithField("output index", i).WithError(err).Warnf("failed to close stderr") + } + }() + } + go func() { + errToggle := errors.Errorf("toggle IO") + for { + prevIsControlSequence := false + if err := copyToFunc(traceReader(in.stdin, func(r rune) (bool, error) { + // Toggle IO when it detects C-a-c + // TODO: make it configurable if needed + if int(r) == 1 { + prevIsControlSequence = true + return false, nil + } + defer func() { prevIsControlSequence = false }() + if prevIsControlSequence { + if string(r) == "c" { + return false, errToggle + } + } + return true, nil + }), func() (io.Writer, error) { + mu.Lock() + o := out[m.cur] + mu.Unlock() + return o.stdin, nil + }); !errors.Is(err, errToggle) { + if err != nil { + logrus.WithError(err).Warnf("failed to read stdin") + } + break + } + m.toggleIO() + } + + // propagate stdin EOF + for i, o := range out { + if err := o.stdin.Close(); err != nil { + logrus.WithError(err).Warnf("failed to close stdin of %d", i) + } + } + wg.Wait() + close(m.closedCh) + }() + return m +} + +type muxIO struct { + cur int + maxCur int + enabled map[int]struct{} + mu sync.Mutex + in ioSetIn + out []ioSetOutContext + closedCh chan struct{} + toggleMessage string +} + +func (m *muxIO) waitClosed() { + <-m.closedCh +} + +func (m *muxIO) enable(i int) { + m.mu.Lock() + defer m.mu.Unlock() + m.enabled[i] = struct{}{} +} + +func (m *muxIO) disable(i int) error { + m.mu.Lock() + defer m.mu.Unlock() + if i == 0 { + return errors.Errorf("disabling 0th io is prohibited") + } + delete(m.enabled, i) + if m.cur == i { + m.toggleIO() + } + return nil +} + +func (m *muxIO) toggleIO() { + if m.out[m.cur].disableHook != nil { + m.out[m.cur].disableHook() + } + for { + if m.cur+1 >= m.maxCur { + m.cur = 0 + } else { + m.cur++ + } + if _, ok := m.enabled[m.cur]; !ok { + continue + } + break + } + if m.out[m.cur].enableHook != nil { + m.out[m.cur].enableHook() + } + fmt.Fprintf(m.in.stdout, m.toggleMessage) +} + +func traceReader(r io.ReadCloser, f func(rune) (bool, error)) io.ReadCloser { + pr, pw := io.Pipe() + go func() { + br := bufio.NewReader(r) + for { + rn, _, err := br.ReadRune() + if err != nil { + if err == io.EOF { + pw.Close() + return + } + pw.CloseWithError(err) + return + } + if isWrite, err := f(rn); err != nil { + pw.CloseWithError(err) + return + } else if !isWrite { + continue + } + if _, err := pw.Write([]byte(string(rn))); err != nil { + pw.CloseWithError(err) + return + } + } + }() + return &readerWithClose{ + Reader: pr, + closeFunc: func() error { + pr.Close() + return r.Close() + }, + } +} + +func copyToFunc(r io.Reader, wFunc func() (io.Writer, error)) error { + buf := make([]byte, 4096) + for { + n, readErr := r.Read(buf) + if readErr != nil && readErr != io.EOF { + return readErr + } + w, err := wFunc() + if err != nil { + return err + } + if w != nil { + if _, err := w.Write(buf[:n]); err != nil { + logrus.WithError(err).Debugf("failed to copy") + } + } + if readErr == io.EOF { + return nil + } + } +} + +func ioSetPipe() (ioSetIn, ioSetOut) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + r3, w3 := io.Pipe() + return ioSetIn{r1, w2, w3}, ioSetOut{w1, r2, r3} +} + +type ioSetIn struct { + stdin io.ReadCloser + stdout io.WriteCloser + stderr io.WriteCloser +} + +func (s ioSetIn) Close() (retErr error) { + if err := s.stdin.Close(); err != nil { + retErr = err + } + if err := s.stdout.Close(); err != nil { + retErr = err + } + if err := s.stderr.Close(); err != nil { + retErr = err + } + return +} + +type ioSetOut struct { + stdin io.WriteCloser + stdout io.ReadCloser + stderr io.ReadCloser +} + +func (s ioSetOut) Close() (retErr error) { + if err := s.stdin.Close(); err != nil { + retErr = err + } + if err := s.stdout.Close(); err != nil { + retErr = err + } + if err := s.stderr.Close(); err != nil { + retErr = err + } + return +} + +type readerWithClose struct { + io.Reader + closeFunc func() error +} + +func (r *readerWithClose) Close() error { + return r.closeFunc() +} diff --git a/monitor/monitor_test.go b/monitor/monitor_test.go new file mode 100644 index 00000000..9eb7a062 --- /dev/null +++ b/monitor/monitor_test.go @@ -0,0 +1,321 @@ +package monitor + +import ( + "bytes" + "fmt" + "io" + "regexp" + "strings" + "testing" + + "golang.org/x/sync/errgroup" +) + +// TestMuxIO tests muxIO +func TestMuxIO(t *testing.T) { + tests := []struct { + name string + inputs []instruction + initIdx int + outputsNum int + wants []string + + // Everytime string is written to the mux stdin, the output end + // that received the string write backs to the string that is masked with + // its index number. This is useful to check if writeback is written from the + // expected output destination. + wantsMaskedOutput string + }{ + { + name: "single output", + inputs: []instruction{ + input("foo\nbar\n"), + toggle(), + input("1234"), + toggle(), + input("456"), + }, + initIdx: 0, + outputsNum: 1, + wants: []string{"foo\nbar\n1234456"}, + wantsMaskedOutput: `^0+$`, + }, + { + name: "multi output", + inputs: []instruction{ + input("foo\nbar\n"), + toggle(), + input("12" + string([]rune{rune(1)}) + "34abc"), + toggle(), + input("456"), + }, + initIdx: 0, + outputsNum: 3, + wants: []string{"foo\nbar\n", "1234abc", "456"}, + wantsMaskedOutput: `^0+1+2+$`, + }, + { + name: "multi output with nonzero index", + inputs: []instruction{ + input("foo\nbar\n"), + toggle(), + input("1234"), + toggle(), + input("456"), + }, + initIdx: 1, + outputsNum: 3, + wants: []string{"456", "foo\nbar\n", "1234"}, + wantsMaskedOutput: `^1+2+0+$`, + }, + { + name: "multi output many toggles", + inputs: []instruction{ + input("foo\nbar\n"), + toggle(), + input("1234"), + toggle(), + toggle(), + input("456"), + toggle(), + input("%%%%"), + toggle(), + toggle(), + toggle(), + input("aaaa"), + }, + initIdx: 0, + outputsNum: 3, + wants: []string{"foo\nbar\n456", "1234%%%%aaaa", ""}, + wantsMaskedOutput: `^0+1+0+1+$`, + }, + { + name: "enable disable", + inputs: []instruction{ + input("foo\nbar\n"), + toggle(), + input("1234"), + toggle(), + input("456"), + disable(2), + input("%%%%"), + enable(2), + toggle(), + toggle(), + input("aaa"), + disable(2), + disable(1), + input("1111"), + toggle(), + input("2222"), + toggle(), + input("3333"), + }, + initIdx: 0, + outputsNum: 3, + wants: []string{"foo\nbar\n%%%%111122223333", "1234", "456aaa"}, + wantsMaskedOutput: `^0+1+2+0+2+0+$`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inBuf, end, in := newTestIn(t) + var outBufs []*outBuf + var outs []ioSetOutContext + if tt.outputsNum != len(tt.wants) { + t.Fatalf("wants != outputsNum") + } + for i := 0; i < tt.outputsNum; i++ { + outBuf, out := newTestOut(t, i) + outBufs = append(outBufs, outBuf) + outs = append(outs, ioSetOutContext{out, nil, nil}) + } + mio := newMuxIO(in, outs, tt.initIdx, "") + for _, i := range tt.inputs { + // Add input to muxIO + istr, writeback := i(mio) + if _, err := end.stdin.Write([]byte(istr)); err != nil { + t.Fatalf("failed to write data to stdin: %v", err) + } + + // Wait for writeback of this input + var eg errgroup.Group + eg.Go(func() error { + outbuf := make([]byte, len(writeback)) + if _, err := io.ReadAtLeast(end.stdout, outbuf, len(outbuf)); err != nil { + return err + } + return nil + }) + eg.Go(func() error { + errbuf := make([]byte, len(writeback)) + if _, err := io.ReadAtLeast(end.stderr, errbuf, len(errbuf)); err != nil { + return err + } + return nil + }) + if err := eg.Wait(); err != nil { + t.Fatalf("failed to wait for output: %v", err) + } + } + + // Close stdin on this muxIO + end.stdin.Close() + + // Wait for all output ends reach EOF + mio.waitClosed() + + // Close stdout/stderr as well + in.Close() + + // Check if each output end received expected string + <-inBuf.doneCh + for i, o := range outBufs { + <-o.doneCh + if o.stdin != tt.wants[i] { + t.Fatalf("output[%d]: got %q; wanted %q", i, o.stdin, tt.wants[i]) + } + } + + // Check if expected string is returned from expected outputs + if !regexp.MustCompile(tt.wantsMaskedOutput).MatchString(inBuf.stdout) { + t.Fatalf("stdout: got %q; wanted %q", inBuf.stdout, tt.wantsMaskedOutput) + } + if !regexp.MustCompile(tt.wantsMaskedOutput).MatchString(inBuf.stderr) { + t.Fatalf("stderr: got %q; wanted %q", inBuf.stderr, tt.wantsMaskedOutput) + } + }) + } +} + +type instruction func(m *muxIO) (intput string, writeBackView string) + +func input(s string) instruction { + return func(m *muxIO) (string, string) { + return s, strings.ReplaceAll(s, string([]rune{rune(1)}), "") + } +} + +func toggle() instruction { + return func(m *muxIO) (string, string) { + return string([]rune{rune(1)}) + "c", "" + } +} + +func enable(i int) instruction { + return func(m *muxIO) (string, string) { + m.enable(i) + return "", "" + } +} + +func disable(i int) instruction { + return func(m *muxIO) (string, string) { + m.disable(i) + return "", "" + } +} + +type inBuf struct { + stdout string + stderr string + doneCh chan struct{} +} + +func newTestIn(t *testing.T) (*inBuf, ioSetOut, ioSetIn) { + ti := &inBuf{ + doneCh: make(chan struct{}), + } + gotOutR, gotOutW := io.Pipe() + gotErrR, gotErrW := io.Pipe() + outR, outW := io.Pipe() + var eg errgroup.Group + eg.Go(func() error { + buf := new(bytes.Buffer) + if _, err := io.Copy(io.MultiWriter(gotOutW, buf), outR); err != nil { + return err + } + ti.stdout = buf.String() + return nil + }) + errR, errW := io.Pipe() + eg.Go(func() error { + buf := new(bytes.Buffer) + if _, err := io.Copy(io.MultiWriter(gotErrW, buf), errR); err != nil { + return err + } + ti.stderr = buf.String() + return nil + }) + go func() { + eg.Wait() + close(ti.doneCh) + }() + inR, inW := io.Pipe() + return ti, ioSetOut{inW, gotOutR, gotErrR}, ioSetIn{inR, outW, errW} +} + +type outBuf struct { + idx int + stdin string + doneCh chan struct{} +} + +func newTestOut(t *testing.T, idx int) (*outBuf, ioSetOut) { + to := &outBuf{ + idx: idx, + doneCh: make(chan struct{}), + } + inR, inW := io.Pipe() + outR, outW := io.Pipe() + errR, errW := io.Pipe() + go func() { + defer inR.Close() + defer outW.Close() + defer errW.Close() + buf := new(bytes.Buffer) + mw := io.MultiWriter(buf, + writeMasked(outW, fmt.Sprintf("%d", to.idx)), + writeMasked(errW, fmt.Sprintf("%d", to.idx)), + ) + if _, err := io.Copy(mw, inR); err != nil { + inR.CloseWithError(err) + outW.CloseWithError(err) + errW.CloseWithError(err) + return + } + to.stdin = string(buf.Bytes()) + outW.Close() + errW.Close() + close(to.doneCh) + }() + return to, ioSetOut{inW, outR, errR} +} + +func writeMasked(w io.Writer, s string) io.Writer { + buf := make([]byte, 4096) + pr, pw := io.Pipe() + go func() { + for { + n, readErr := pr.Read(buf) + if readErr != nil && readErr != io.EOF { + pr.CloseWithError(readErr) + return + } + var masked string + for i := 0; i < n; i++ { + masked += s + } + if _, err := w.Write([]byte(masked)); err != nil { + pr.CloseWithError(err) + return + } + if readErr == io.EOF { + pr.Close() + return + } + } + }() + return pw +}