bake: allow interception to create derived contexts

This patch allows high level clients to define an EvalContext method
which can derive a new context given a block and the base parent
context.

This allows users of the package to intercept evaluation before it
begins, and define additional variables and functions that are bound to
a single block.

Signed-off-by: Justin Chadwell <me@jedevc.com>
pull/1690/head
Justin Chadwell 2 years ago
parent 624bc064d8
commit 1613fde55c

@ -51,6 +51,7 @@ type parser struct {
blocks map[string]map[string][]*hcl.Block blocks map[string]map[string][]*hcl.Block
blockValues map[*hcl.Block]reflect.Value blockValues map[*hcl.Block]reflect.Value
blockEvalCtx map[*hcl.Block]*hcl.EvalContext
blockTypes map[string]reflect.Type blockTypes map[string]reflect.Type
ectx *hcl.EvalContext ectx *hcl.EvalContext
@ -58,20 +59,19 @@ type parser struct {
progress map[string]struct{} progress map[string]struct{}
progressF map[string]struct{} progressF map[string]struct{}
progressB map[*hcl.Block]map[string]struct{} progressB map[*hcl.Block]map[string]struct{}
doneF map[string]struct{}
doneB map[*hcl.Block]map[string]struct{} doneB map[*hcl.Block]map[string]struct{}
} }
var errUndefined = errors.New("undefined") var errUndefined = errors.New("undefined")
func (p *parser) loadDeps(exp hcl.Expression, exclude map[string]struct{}, allowMissing bool) hcl.Diagnostics { func (p *parser) loadDeps(ectx *hcl.EvalContext, exp hcl.Expression, exclude map[string]struct{}, allowMissing bool) hcl.Diagnostics {
fns, hcldiags := funcCalls(exp) fns, hcldiags := funcCalls(exp)
if hcldiags.HasErrors() { if hcldiags.HasErrors() {
return hcldiags return hcldiags
} }
for _, fn := range fns { for _, fn := range fns {
if err := p.resolveFunction(fn); err != nil { if err := p.resolveFunction(ectx, fn); err != nil {
if allowMissing && errors.Is(err, errUndefined) { if allowMissing && errors.Is(err, errUndefined) {
continue continue
} }
@ -131,7 +131,7 @@ func (p *parser) loadDeps(exp hcl.Expression, exclude map[string]struct{}, allow
return wrapErrorDiagnostic("Invalid expression", err, exp.Range().Ptr(), exp.Range().Ptr()) return wrapErrorDiagnostic("Invalid expression", err, exp.Range().Ptr(), exp.Range().Ptr())
} }
} else { } else {
if err := p.resolveValue(v.RootName()); err != nil { if err := p.resolveValue(ectx, v.RootName()); err != nil {
if allowMissing && errors.Is(err, errUndefined) { if allowMissing && errors.Is(err, errUndefined) {
continue continue
} }
@ -145,16 +145,16 @@ func (p *parser) loadDeps(exp hcl.Expression, exclude map[string]struct{}, allow
// resolveFunction forces evaluation of a function, storing the result into the // resolveFunction forces evaluation of a function, storing the result into the
// parser. // parser.
func (p *parser) resolveFunction(name string) error { func (p *parser) resolveFunction(ectx *hcl.EvalContext, name string) error {
if _, ok := p.doneF[name]; ok { if _, ok := p.ectx.Functions[name]; ok {
return nil return nil
} }
f, ok := p.funcs[name] if _, ok := ectx.Functions[name]; ok {
if !ok {
if _, ok := p.ectx.Functions[name]; ok {
return nil return nil
} }
return errors.Wrapf(errUndefined, "function %q does not exit", name) f, ok := p.funcs[name]
if !ok {
return errors.Wrapf(errUndefined, "function %q does not exist", name)
} }
if _, ok := p.progressF[name]; ok { if _, ok := p.progressF[name]; ok {
return errors.Errorf("function cycle not allowed for %s", name) return errors.Errorf("function cycle not allowed for %s", name)
@ -204,7 +204,7 @@ func (p *parser) resolveFunction(name string) error {
return diags return diags
} }
if diags := p.loadDeps(f.Result.Expr, params, false); diags.HasErrors() { if diags := p.loadDeps(p.ectx, f.Result.Expr, params, false); diags.HasErrors() {
return diags return diags
} }
@ -214,7 +214,6 @@ func (p *parser) resolveFunction(name string) error {
if diags.HasErrors() { if diags.HasErrors() {
return diags return diags
} }
p.doneF[name] = struct{}{}
p.ectx.Functions[name] = v p.ectx.Functions[name] = v
return nil return nil
@ -222,10 +221,13 @@ func (p *parser) resolveFunction(name string) error {
// resolveValue forces evaluation of a named value, storing the result into the // resolveValue forces evaluation of a named value, storing the result into the
// parser. // parser.
func (p *parser) resolveValue(name string) (err error) { func (p *parser) resolveValue(ectx *hcl.EvalContext, name string) (err error) {
if _, ok := p.ectx.Variables[name]; ok { if _, ok := p.ectx.Variables[name]; ok {
return nil return nil
} }
if _, ok := ectx.Variables[name]; ok {
return nil
}
if _, ok := p.progress[name]; ok { if _, ok := p.progress[name]; ok {
return errors.Errorf("variable cycle not allowed for %s", name) return errors.Errorf("variable cycle not allowed for %s", name)
} }
@ -242,9 +244,10 @@ func (p *parser) resolveValue(name string) (err error) {
if _, builtin := p.opt.Vars[name]; !ok && !builtin { if _, builtin := p.opt.Vars[name]; !ok && !builtin {
vr, ok := p.vars[name] vr, ok := p.vars[name]
if !ok { if !ok {
return errors.Wrapf(errUndefined, "variable %q does not exit", name) return errors.Wrapf(errUndefined, "variable %q does not exist", name)
} }
def = vr.Default def = vr.Default
ectx = p.ectx
} }
if def == nil { if def == nil {
@ -257,10 +260,10 @@ func (p *parser) resolveValue(name string) (err error) {
return return
} }
if diags := p.loadDeps(def.Expr, nil, true); diags.HasErrors() { if diags := p.loadDeps(ectx, def.Expr, nil, true); diags.HasErrors() {
return diags return diags
} }
vv, diags := def.Expr.Value(p.ectx) vv, diags := def.Expr.Value(ectx)
if diags.HasErrors() { if diags.HasErrors() {
return diags return diags
} }
@ -364,18 +367,43 @@ func (p *parser) resolveBlock(block *hcl.Block, target *hcl.BodySchema) (err err
return FilterExcludeBody(block.Body, filter) return FilterExcludeBody(block.Body, filter)
} }
// load dependencies from all targeted properties // prepare the output destination and evaluation context
t, ok := p.blockTypes[block.Type] t, ok := p.blockTypes[block.Type]
if !ok { if !ok {
return nil return nil
} }
var output reflect.Value
var ectx *hcl.EvalContext
if prev, ok := p.blockValues[block]; ok {
output = prev
ectx = p.blockEvalCtx[block]
} else {
output = reflect.New(t)
setLabel(output, block.Labels[0]) // early attach labels, so we can reference them
type ectxI interface {
EvalContext(base *hcl.EvalContext, block *hcl.Block) *hcl.EvalContext
}
if v, ok := output.Interface().(ectxI); ok {
ectx = v.EvalContext(p.ectx, block)
if ectx != p.ectx && ectx.Parent() != p.ectx {
return errors.Errorf("EvalContext must return a context with the correct parent")
}
} else {
ectx = p.ectx
}
}
p.blockValues[block] = output
p.blockEvalCtx[block] = ectx
// load dependencies from all targeted properties
schema, _ := gohcl.ImpliedBodySchema(reflect.New(t).Interface()) schema, _ := gohcl.ImpliedBodySchema(reflect.New(t).Interface())
content, _, diag := body().PartialContent(schema) content, _, diag := body().PartialContent(schema)
if diag.HasErrors() { if diag.HasErrors() {
return diag return diag
} }
for _, a := range content.Attributes { for _, a := range content.Attributes {
diag := p.loadDeps(a.Expr, nil, true) diag := p.loadDeps(ectx, a.Expr, nil, true)
if diag.HasErrors() { if diag.HasErrors() {
return diag return diag
} }
@ -388,18 +416,10 @@ func (p *parser) resolveBlock(block *hcl.Block, target *hcl.BodySchema) (err err
} }
// decode! // decode!
var output reflect.Value diag = gohcl.DecodeBody(body(), ectx, output.Interface())
if prev, ok := p.blockValues[block]; ok {
output = prev
} else {
output = reflect.New(t)
setLabel(output, block.Labels[0]) // early attach labels, so we can reference them
}
diag = gohcl.DecodeBody(body(), p.ectx, output.Interface())
if diag.HasErrors() { if diag.HasErrors() {
return diag return diag
} }
p.blockValues[block] = output
// mark all targeted properties as done // mark all targeted properties as done
for _, a := range content.Attributes { for _, a := range content.Attributes {
@ -417,7 +437,7 @@ func (p *parser) resolveBlock(block *hcl.Block, target *hcl.BodySchema) (err err
} }
} }
// store the result into the evaluation context (so if can be referenced) // store the result into the evaluation context (so it can be referenced)
outputType, err := gocty.ImpliedType(output.Interface()) outputType, err := gocty.ImpliedType(output.Interface())
if err != nil { if err != nil {
return err return err
@ -477,18 +497,18 @@ func Parse(b hcl.Body, opt Opt, val interface{}) hcl.Diagnostics {
blocks: map[string]map[string][]*hcl.Block{}, blocks: map[string]map[string][]*hcl.Block{},
blockValues: map[*hcl.Block]reflect.Value{}, blockValues: map[*hcl.Block]reflect.Value{},
blockEvalCtx: map[*hcl.Block]*hcl.EvalContext{},
blockTypes: map[string]reflect.Type{}, blockTypes: map[string]reflect.Type{},
ectx: &hcl.EvalContext{
Variables: map[string]cty.Value{},
Functions: Stdlib(),
},
progress: map[string]struct{}{}, progress: map[string]struct{}{},
progressF: map[string]struct{}{}, progressF: map[string]struct{}{},
progressB: map[*hcl.Block]map[string]struct{}{}, progressB: map[*hcl.Block]map[string]struct{}{},
doneF: map[string]struct{}{},
doneB: map[*hcl.Block]map[string]struct{}{}, doneB: map[*hcl.Block]map[string]struct{}{},
ectx: &hcl.EvalContext{
Variables: map[string]cty.Value{},
Functions: stdlibFunctions,
},
} }
for _, v := range defs.Variables { for _, v := range defs.Variables {
@ -532,7 +552,7 @@ func Parse(b hcl.Body, opt Opt, val interface{}) hcl.Diagnostics {
delete(p.attrs, "function") delete(p.attrs, "function")
for k := range p.opt.Vars { for k := range p.opt.Vars {
_ = p.resolveValue(k) _ = p.resolveValue(p.ectx, k)
} }
for _, a := range content.Attributes { for _, a := range content.Attributes {
@ -548,7 +568,7 @@ func Parse(b hcl.Body, opt Opt, val interface{}) hcl.Diagnostics {
} }
for k := range p.vars { for k := range p.vars {
if err := p.resolveValue(k); err != nil { if err := p.resolveValue(p.ectx, k); err != nil {
if diags, ok := err.(hcl.Diagnostics); ok { if diags, ok := err.(hcl.Diagnostics); ok {
return diags return diags
} }
@ -558,7 +578,7 @@ func Parse(b hcl.Body, opt Opt, val interface{}) hcl.Diagnostics {
} }
for k := range p.funcs { for k := range p.funcs {
if err := p.resolveFunction(k); err != nil { if err := p.resolveFunction(p.ectx, k); err != nil {
if diags, ok := err.(hcl.Diagnostics); ok { if diags, ok := err.(hcl.Diagnostics); ok {
return diags return diags
} }
@ -678,7 +698,7 @@ func Parse(b hcl.Body, opt Opt, val interface{}) hcl.Diagnostics {
} }
for k := range p.attrs { for k := range p.attrs {
if err := p.resolveValue(k); err != nil { if err := p.resolveValue(p.ectx, k); err != nil {
if diags, ok := err.(hcl.Diagnostics); ok { if diags, ok := err.(hcl.Diagnostics); ok {
return diags return diags
} }

@ -124,3 +124,11 @@ var timestampFunc = function.New(&function.Spec{
return cty.StringVal(time.Now().UTC().Format(time.RFC3339)), nil return cty.StringVal(time.Now().UTC().Format(time.RFC3339)), nil
}, },
}) })
func Stdlib() map[string]function.Function {
funcs := make(map[string]function.Function, len(stdlibFunctions))
for k, v := range stdlibFunctions {
funcs[k] = v
}
return funcs
}

Loading…
Cancel
Save