diff --git a/go.mod b/go.mod index 166730c8..ff5d1149 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 // indirect github.com/ghodss/yaml v1.0.0 // indirect github.com/go-sql-driver/mysql v1.4.1 // indirect + github.com/gofrs/flock v0.7.0 github.com/gofrs/uuid v3.2.0+incompatible // indirect github.com/gogo/googleapis v1.1.0 // indirect github.com/gogo/protobuf v1.2.1 // indirect diff --git a/store/store.go b/store/store.go new file mode 100644 index 00000000..0f2bd4fd --- /dev/null +++ b/store/store.go @@ -0,0 +1,258 @@ +package store + +import ( + "encoding/json" + "io" + "io/ioutil" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + + "github.com/gofrs/flock" + "github.com/opencontainers/go-digest" + "github.com/pkg/errors" +) + +type NodeGroup struct { + Name string + Driver string + Nodes []Node + Endpoint string +} + +type Node struct { + Name string + Endpoint string + Platforms []string +} + +func NewStore(root string) (*Store, error) { + root = filepath.Join(root, "buildx") + if err := os.MkdirAll(filepath.Join(root, "instances"), 0700); err != nil { + return nil, err + } + if err := os.MkdirAll(filepath.Join(root, "defaults"), 0700); err != nil { + return nil, err + } + return &Store{root: root}, nil +} + +type Store struct { + root string +} + +func (s *Store) Txn() (*Txn, func(), error) { + l := flock.New(filepath.Join(s.root, ".lock")) + if err := l.Lock(); err != nil { + return nil, nil, err + } + return &Txn{ + s: s, + }, func() { + l.Close() + }, nil +} + +type Txn struct { + s *Store +} + +func (t *Txn) List() ([]*NodeGroup, error) { + pp := filepath.Join(t.s.root, "instances") + fis, err := ioutil.ReadDir(pp) + if err != nil { + return nil, err + } + ngs := make([]*NodeGroup, 0, len(fis)) + for _, fi := range fis { + ng, err := t.NodeGroupByName(fi.Name()) + if err != nil { + if os.IsNotExist(errors.Cause(err)) { + os.RemoveAll(filepath.Join(pp, fi.Name())) + continue + } + return nil, err + } + ngs = append(ngs, ng) + } + + sort.Slice(ngs, func(i, j int) bool { + return ngs[i].Name < ngs[j].Name + }) + + return ngs, nil +} + +func (t *Txn) NodeGroupByName(name string) (*NodeGroup, error) { + name, err := validateName(name) + if err != nil { + return nil, err + } + dt, err := ioutil.ReadFile(filepath.Join(t.s.root, "instances", name)) + if err != nil { + return nil, err + } + var ng NodeGroup + if err := json.Unmarshal(dt, &ng); err != nil { + return nil, err + } + return &ng, nil +} + +func (t *Txn) Save(ng *NodeGroup) error { + name, err := validateName(ng.Name) + if err != nil { + return err + } + dt, err := json.Marshal(ng) + if err != nil { + return err + } + return atomicWriteFile(filepath.Join(t.s.root, "instances", name), dt, 0600) +} + +func (t *Txn) Remove(name string) error { + name, err := validateName(name) + if err != nil { + return err + } + return os.RemoveAll(filepath.Join(t.s.root, "instances", name)) +} + +func (t *Txn) SetCurrent(key, name string, global, def bool) error { + c := current{ + Key: key, + Name: name, + Global: global, + } + dt, err := json.Marshal(c) + if err != nil { + return err + } + if err := atomicWriteFile(filepath.Join(t.s.root, "current"), dt, 0600); err != nil { + return err + } + + h := toHash(key) + + if def { + if err := atomicWriteFile(filepath.Join(t.s.root, "defaults", h), []byte(name), 0600); err != nil { + return err + } + } else { + os.RemoveAll(filepath.Join(t.s.root, "defaults", h)) // ignore error + } + return nil +} + +func (t *Txn) reset(key string) error { + dt, err := json.Marshal(current{Key: key}) + if err != nil { + return err + } + if err := atomicWriteFile(filepath.Join(t.s.root, "current"), dt, 0600); err != nil { + return err + } + return nil +} + +func (t *Txn) Current(key string) (*NodeGroup, error) { + dt, err := ioutil.ReadFile(filepath.Join(t.s.root, "current")) + if err != nil { + if !os.IsNotExist(err) { + return nil, err + } + } + if err == nil { + var c current + if err := json.Unmarshal(dt, &c); err != nil { + return nil, err + } + if c.Name != "" { + if c.Global { + ng, err := t.NodeGroupByName(c.Name) + if err == nil { + return ng, nil + } + } + + if c.Key == key { + ng, err := t.NodeGroupByName(c.Name) + if err == nil { + return ng, nil + } + return nil, nil + } + } + } + + h := toHash(key) + + dt, err = ioutil.ReadFile(filepath.Join(t.s.root, "defaults", h)) + if err != nil { + if os.IsNotExist(err) { + t.reset(key) + return nil, nil + } + return nil, err + } + + ng, err := t.NodeGroupByName(string(dt)) + if err != nil { + t.reset(key) + } + if err := t.SetCurrent(key, string(dt), false, true); err != nil { + return nil, err + } + return ng, nil +} + +type current struct { + Key string + Name string + Global bool +} + +var namePattern = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9\.\-_\+]*$`) + +func validateName(s string) (string, error) { + if !namePattern.MatchString(s) { + return "", errors.Errorf("invalid name %s, name needs to start with a letter and may not contain symbols, except ._-", s) + } + return strings.ToLower(s), nil +} + +func atomicWriteFile(filename string, data []byte, perm os.FileMode) error { + f, err := ioutil.TempFile(filepath.Dir(filename), ".tmp-"+filepath.Base(filename)) + if err != nil { + return err + } + err = os.Chmod(f.Name(), perm) + if err != nil { + f.Close() + return err + } + n, err := f.Write(data) + if err == nil && n < len(data) { + f.Close() + return io.ErrShortWrite + } + if err != nil { + f.Close() + return err + } + if err := f.Sync(); err != nil { + f.Close() + return err + } + if err := f.Close(); err != nil { + return err + } + return os.Rename(f.Name(), filename) +} + +func toHash(in string) string { + return digest.FromBytes([]byte(in)).Hex()[:20] +} diff --git a/store/store_test.go b/store/store_test.go new file mode 100644 index 00000000..2dd59f58 --- /dev/null +++ b/store/store_test.go @@ -0,0 +1,237 @@ +package store + +import ( + "io/ioutil" + "os" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestEmptyStartup(t *testing.T) { + t.Parallel() + tmpdir, err := ioutil.TempDir("", "buildx-store") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + s, err := New(tmpdir) + require.NoError(t, err) + + txn, close, err := s.Txn() + require.NoError(t, err) + defer close() + + ng, err := txn.Current("foo") + require.NoError(t, err) + require.Nil(t, ng) +} + +func TestNodeLocking(t *testing.T) { + t.Parallel() + tmpdir, err := ioutil.TempDir("", "buildx-store") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + s, err := New(tmpdir) + require.NoError(t, err) + + _, release, err := s.Txn() + require.NoError(t, err) + + ready := make(chan struct{}) + + go func() { + _, release, err := s.Txn() + require.NoError(t, err) + release() + close(ready) + }() + + select { + case <-time.After(100 * time.Millisecond): + case <-ready: + require.Fail(t, "transaction should have waited") + } + + release() + select { + case <-time.After(200 * time.Millisecond): + require.Fail(t, "transaction should have completed") + case <-ready: + } +} + +func TestNodeManagement(t *testing.T) { + t.Parallel() + tmpdir, err := ioutil.TempDir("", "buildx-store") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + s, err := New(tmpdir) + require.NoError(t, err) + + txn, release, err := s.Txn() + require.NoError(t, err) + defer release() + + err = txn.Save(&NodeGroup{ + Name: "foo/bar", + Driver: "driver", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid name") + + err = txn.Save(&NodeGroup{ + Name: "mybuild", + Driver: "mydriver", + }) + require.NoError(t, err) + + ng, err := txn.NodeGroupByName("mybuild") + require.NoError(t, err) + require.Equal(t, "mybuild", ng.Name) + require.Equal(t, "mydriver", ng.Driver) + + _, err = txn.NodeGroupByName("mybuild2") + require.Error(t, err) + require.True(t, os.IsNotExist(errors.Cause(err))) + + err = txn.Save(&NodeGroup{ + Name: "mybuild2", + Driver: "mydriver2", + }) + require.NoError(t, err) + + ng, err = txn.NodeGroupByName("mybuild2") + require.NoError(t, err) + require.Equal(t, "mybuild2", ng.Name) + require.Equal(t, "mydriver2", ng.Driver) + + // update existing + err = txn.Save(&NodeGroup{ + Name: "mybuild", + Driver: "mydriver-mod", + }) + require.NoError(t, err) + + ng, err = txn.NodeGroupByName("mybuild") + require.NoError(t, err) + require.Equal(t, "mybuild", ng.Name) + require.Equal(t, "mydriver-mod", ng.Driver) + + ngs, err := txn.List() + require.NoError(t, err) + require.Equal(t, 2, len(ngs)) + + // test setting current + err = txn.SetCurrent("foo", "mybuild", false, false) + require.NoError(t, err) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild", ng.Name) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild", ng.Name) + + ng, err = txn.Current("bar") + require.NoError(t, err) + require.Nil(t, ng) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.Nil(t, ng) + + // set with default + err = txn.SetCurrent("foo", "mybuild", false, true) + require.NoError(t, err) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild", ng.Name) + + ng, err = txn.Current("bar") + require.NoError(t, err) + require.Nil(t, ng) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild", ng.Name) + + err = txn.SetCurrent("foo", "mybuild2", false, true) + require.NoError(t, err) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild2", ng.Name) + + err = txn.SetCurrent("bar", "mybuild", false, false) + require.NoError(t, err) + + ng, err = txn.Current("bar") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild", ng.Name) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild2", ng.Name) + + // set global + err = txn.SetCurrent("foo", "mybuild2", true, false) + require.NoError(t, err) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild2", ng.Name) + + ng, err = txn.Current("bar") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild2", ng.Name) + + err = txn.SetCurrent("bar", "mybuild", false, false) + require.NoError(t, err) + + ng, err = txn.Current("bar") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild", ng.Name) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.Nil(t, ng) + + err = txn.SetCurrent("bar", "mybuild", false, true) + require.NoError(t, err) + + err = txn.SetCurrent("foo", "mybuild2", false, false) + require.NoError(t, err) + + // test removal + err = txn.Remove("mybuild2") + require.NoError(t, err) + + _, err = txn.NodeGroupByName("mybuild2") + require.Error(t, err) + require.True(t, os.IsNotExist(errors.Cause(err))) + + ng, err = txn.Current("foo") + require.NoError(t, err) + require.Nil(t, ng) + + ng, err = txn.Current("bar") + require.NoError(t, err) + require.NotNil(t, ng) + require.Equal(t, "mybuild", ng.Name) +}