package store

import (
	"encoding/json"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"regexp"
	"sort"

	"github.com/gofrs/flock"
	"github.com/opencontainers/go-digest"
	"github.com/pkg/errors"
)

// New prepares the on-disk layout under root and returns a Store bound to it.
//
// The "instances" and "defaults" subdirectories are created with mode 0700
// (along with any missing parents). The first directory-creation error is
// returned as-is and no Store is produced.
func New(root string) (*Store, error) {
	for _, sub := range []string{"instances", "defaults"} {
		if err := os.MkdirAll(filepath.Join(root, sub), 0700); err != nil {
			return nil, err
		}
	}
	return &Store{root: root}, nil
}

// Store is a handle to a directory-backed store. All state lives in files
// beneath root; access goes through a Txn obtained from Store.Txn.
type Store struct {
	// root is the base directory. The code in this file places the
	// "instances" and "defaults" directories, the "current" file and the
	// ".lock" file directly under it.
	root string
}

// Txn takes the store's file lock ("<root>/.lock") and returns a transaction
// handle together with a release function.
//
// The caller must invoke the returned function when finished; it closes the
// lock handle (any error from closing is discarded). If the lock cannot be
// taken, the error is returned and both other results are nil.
func (s *Store) Txn() (*Txn, func(), error) {
	lock := flock.New(filepath.Join(s.root, ".lock"))
	if err := lock.Lock(); err != nil {
		return nil, nil, err
	}

	release := func() {
		lock.Close()
	}
	return &Txn{s: s}, release, nil
}

// Txn groups the read and write operations on a Store. Store.Txn hands one
// out only after it has acquired the store's file lock.
type Txn struct {
	// s is the Store whose root directory every operation reads and writes.
	s *Store
}

// List loads every node group stored under "<root>/instances" and returns
// them sorted by Name.
//
// Each directory entry is loaded through NodeGroupByName. An entry whose load
// fails with a not-exist error is deleted from the directory (best effort)
// and skipped; any other load error aborts the listing and is returned.
func (t *Txn) List() ([]*NodeGroup, error) {
	dir := filepath.Join(t.s.root, "instances")
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	groups := make([]*NodeGroup, 0, len(entries))
	for _, entry := range entries {
		ng, err := t.NodeGroupByName(entry.Name())
		switch {
		case err == nil:
			groups = append(groups, ng)
		case os.IsNotExist(errors.Cause(err)):
			// The entry was listed but cannot be read: drop it and move on.
			os.RemoveAll(filepath.Join(dir, entry.Name()))
		default:
			return nil, err
		}
	}

	sort.Slice(groups, func(a, b int) bool {
		return groups[a].Name < groups[b].Name
	})
	return groups, nil
}

// NodeGroupByName loads the node group stored under the given name.
//
// The name is first passed through ValidateName, and the value it returns is
// the file name read from "<root>/instances". The file content is decoded as
// JSON into a NodeGroup. Validation, read and decode errors are returned
// unchanged with a nil node group.
func (t *Txn) NodeGroupByName(name string) (*NodeGroup, error) {
	validated, err := ValidateName(name)
	if err != nil {
		return nil, err
	}

	raw, err := ioutil.ReadFile(filepath.Join(t.s.root, "instances", validated))
	if err != nil {
		return nil, err
	}

	ng := new(NodeGroup)
	if err := json.Unmarshal(raw, ng); err != nil {
		return nil, err
	}
	return ng, nil
}

// Save persists ng as JSON in "<root>/instances", using the name returned by
// ValidateName(ng.Name) as the file name.
//
// The file is written through atomicWriteFile with mode 0600. Validation,
// encoding and write errors are returned unchanged.
func (t *Txn) Save(ng *NodeGroup) error {
	validated, err := ValidateName(ng.Name)
	if err != nil {
		return err
	}

	payload, err := json.Marshal(ng)
	if err != nil {
		return err
	}

	dest := filepath.Join(t.s.root, "instances", validated)
	return atomicWriteFile(dest, payload, 0600)
}

// Remove deletes the stored node group with the given name.
//
// The name is validated with ValidateName first; the entry under
// "<root>/instances" is then removed with os.RemoveAll, so a missing entry
// is not an error.
func (t *Txn) Remove(name string) error {
	validated, err := ValidateName(name)
	if err != nil {
		return err
	}

	target := filepath.Join(t.s.root, "instances", validated)
	return os.RemoveAll(target)
}

// SetCurrent records name as the current selection for key.
//
// The selection (key, name, global) is written as JSON to "<root>/current".
// The per-key default file "<root>/defaults/<toHash(key)>" is then updated:
// when def is true it is written with name as its content; otherwise it is
// removed, ignoring any removal error.
//
// Encoding and write errors are returned unchanged.
func (t *Txn) SetCurrent(key, name string, global, def bool) error {
	payload, err := json.Marshal(current{Key: key, Name: name, Global: global})
	if err != nil {
		return err
	}
	if err := atomicWriteFile(filepath.Join(t.s.root, "current"), payload, 0600); err != nil {
		return err
	}

	defaultPath := filepath.Join(t.s.root, "defaults", toHash(key))
	if !def {
		os.RemoveAll(defaultPath) // ignore error
		return nil
	}
	return atomicWriteFile(defaultPath, []byte(name), 0600)
}

// reset overwrites "<root>/current" with a selection that carries only key:
// no name and not global.
func (t *Txn) reset(key string) error {
	payload, err := json.Marshal(current{Key: key})
	if err != nil {
		return err
	}
	return atomicWriteFile(filepath.Join(t.s.root, "current"), payload, 0600)
}

// Current returns the node group currently selected for key, or (nil, nil)
// when nothing usable is selected.
//
// Resolution order:
//
//  1. The "<root>/current" file. A selection marked global applies to any
//     key; otherwise it applies only if it was recorded for this key. A
//     selection recorded for this key whose node group cannot be loaded
//     yields (nil, nil) without consulting the defaults.
//  2. The per-key default in "<root>/defaults/<toHash(key)>". If it loads,
//     it is recorded as the current (non-global, default) selection and
//     returned. If the default file is missing, or names a node group that
//     cannot be loaded, the current selection is reset to the bare key and
//     (nil, nil) is returned.
//
// Errors from reading or decoding the "current" file, from reading the
// default file (other than it not existing), and from recording the promoted
// default are returned. A failure to load a node group by name is treated as
// "no selection", not as an error.
func (t *Txn) Current(key string) (*NodeGroup, error) {
	dt, err := ioutil.ReadFile(filepath.Join(t.s.root, "current"))
	if err != nil && !os.IsNotExist(err) {
		return nil, err
	}
	if err == nil {
		var c current
		if err := json.Unmarshal(dt, &c); err != nil {
			return nil, err
		}
		if c.Name != "" && (c.Global || c.Key == key) {
			if ng, err := t.NodeGroupByName(c.Name); err == nil {
				return ng, nil
			}
			if c.Key == key {
				// The selection was made for this key but is unusable.
				return nil, nil
			}
			// An unusable global selection made for another key: fall
			// through to this key's default.
		}
	}

	dt, err = ioutil.ReadFile(filepath.Join(t.s.root, "defaults", toHash(key)))
	if err != nil {
		if os.IsNotExist(err) {
			t.reset(key) // best effort
			return nil, nil
		}
		return nil, err
	}

	ng, err := t.NodeGroupByName(string(dt))
	if err != nil {
		// The default names a node group that cannot be loaded. Clear the
		// selection and stop here; falling through to SetCurrent would
		// overwrite the reset and re-pin the dangling name.
		t.reset(key) // best effort
		return nil, nil
	}
	if err := t.SetCurrent(key, string(dt), false, true); err != nil {
		return nil, err
	}
	return ng, nil
}

// current is the JSON document stored in "<root>/current"; it records the
// most recent selection made through SetCurrent or reset.
type current struct {
	// Key is the caller-supplied key the selection was recorded for.
	Key string
	// Name is the selected node group's name; empty after a reset.
	Name string
	// Global, when true, makes Current try this selection for every key,
	// not only for Key.
	Global bool
}

// namePattern matches strings that start with an ASCII letter and continue
// with any number of ASCII letters, digits, '.', '-' or '_'.
// NOTE(review): not referenced in the visible part of this file; presumably
// used by ValidateName — confirm.
var namePattern = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9\.\-_]*$`)

// atomicWriteFile writes data to filename so that readers see either the old
// content or the complete new content, never a partial write.
//
// The data is written to a temporary file (".tmp-<base>*") created in the
// same directory as filename, given mode perm, synced and closed, and then
// renamed over filename. If any step fails, the temporary file is closed and
// removed so that no ".tmp-*" residue is left next to the target, and the
// error from the failing step is returned.
func atomicWriteFile(filename string, data []byte, perm os.FileMode) (err error) {
	f, err := ioutil.TempFile(filepath.Dir(filename), ".tmp-"+filepath.Base(filename))
	if err != nil {
		return err
	}
	tmp := f.Name()

	// On any failure below, release the handle and delete the temp file.
	// Both calls are best effort: Close may already have happened, and the
	// original error is the one worth reporting.
	defer func() {
		if err != nil {
			f.Close()
			os.Remove(tmp)
		}
	}()

	if err = os.Chmod(tmp, perm); err != nil {
		return err
	}

	n, err := f.Write(data)
	if err == nil && n < len(data) {
		err = io.ErrShortWrite
	}
	if err != nil {
		return err
	}

	// Flush to stable storage before the rename makes the file visible.
	if err = f.Sync(); err != nil {
		return err
	}
	if err = f.Close(); err != nil {
		return err
	}
	return os.Rename(tmp, filename)
}

// toHash returns the first 20 characters of the hex-encoded digest of in,
// as computed by digest.FromBytes. It is used to derive a file name from a
// key.
func toHash(in string) string {
	hexDigest := digest.FromBytes([]byte(in)).Hex()
	return hexDigest[:20]
}