From 23156f54020f48dd74847eed8393298e7ab6803e Mon Sep 17 00:00:00 2001 From: Brian Goff Date: Fri, 31 Jan 2025 09:49:12 -0800 Subject: [PATCH] Preserve extension fields when marshalling/unmarshalling Signed-off-by: Brian Goff WIP: smarter yaml parser use --- go.mod | 2 +- go.sum | 6 ++-- load.go | 84 ++++++++++++++++++++++++++++++++++++++++++---------- load_test.go | 51 ++++++++++++++++++++++++++++++- spec.go | 60 +++++++++++++++++++++++++++++++++++++ 5 files changed, 184 insertions(+), 19 deletions(-) diff --git a/go.mod b/go.mod index fc69c3b91..138aa6858 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.23.1 require ( github.com/atombender/go-jsonschema v0.17.0 github.com/containerd/platforms v1.0.0-rc.1 - github.com/goccy/go-yaml v1.15.16 + github.com/goccy/go-yaml v1.15.22 github.com/google/go-cmp v0.6.0 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/invopop/jsonschema v0.12.0 diff --git a/go.sum b/go.sum index df716c7db..a60f5887f 100644 --- a/go.sum +++ b/go.sum @@ -75,8 +75,10 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/goccy/go-yaml v1.15.16 h1:PMTVcGI9uNPIn7KLs0H7KC1rE+51yPl5YNh4i8rGuRA= -github.com/goccy/go-yaml v1.15.16/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/goccy/go-yaml v1.15.20 h1:eQHFLrr1lpLYAxupPD9ThZbGtncPl9nyu3nkAayEZgY= +github.com/goccy/go-yaml v1.15.20/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/goccy/go-yaml v1.15.22 h1:iQI1hvCoiYYiVFq76P4AI8ImgDOfgiyKnl/AWjK8/gA= +github.com/goccy/go-yaml v1.15.22/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= github.com/gofrs/flock 
v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= diff --git a/load.go b/load.go index 23f4d3f7e..90e369ec8 100644 --- a/load.go +++ b/load.go @@ -7,6 +7,9 @@ import ( "strings" "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/ast" + "github.com/goccy/go-yaml/parser" + "github.com/goccy/go-yaml/token" "github.com/moby/buildkit/frontend/dockerfile/shell" "github.com/pkg/errors" "golang.org/x/exp/maps" @@ -210,36 +213,87 @@ func (s *Spec) SubstituteArgs(env map[string]string, opts ...SubstituteOpt) erro func LoadSpec(dt []byte) (*Spec, error) { var spec Spec - dt, err := stripXFields(dt) - if err != nil { - return nil, fmt.Errorf("error stripping x-fields: %w", err) - } - if err := yaml.UnmarshalWithOptions(dt, &spec, yaml.Strict()); err != nil { - return nil, fmt.Errorf("error unmarshalling spec: %w", err) + return nil, errors.Wrap(err, "error unmarshalling spec") } if err := spec.Validate(); err != nil { return nil, err } - spec.FillDefaults() + spec.FillDefaults() return &spec, nil } -func stripXFields(dt []byte) ([]byte, error) { - var obj map[string]interface{} - if err := yaml.Unmarshal(dt, &obj); err != nil { - return nil, fmt.Errorf("error unmarshalling spec: %w", err) +func (s *Spec) UnmarshalYAML(dt []byte) error { + parsed, err := parser.ParseBytes(dt, parser.ParseComments) + if err != nil { + return errors.Wrapf(err, "error parsing yaml: \n%s", string(dt)) + } + + if len(parsed.Docs) != 1 { + return errors.New("expected exactly one yaml document") } - for k := range obj { - if strings.HasPrefix(k, "x-") || strings.HasPrefix(k, "X-") { - delete(obj, k) + body := parsed.Docs[0].Body.(*ast.MappingNode) + var extNodes []*ast.MappingValueNode + for i := 0; i < len(body.Values); i++ { + node := body.Values[i] + p := node.GetPath() + if !strings.HasPrefix(p, "$.x-") && !strings.HasPrefix(p, "$.X-") { + continue } + + // Delete the extension node from the AST. 
+ body.Values = append(body.Values[:i], body.Values[i+1:]...) + i-- + extNodes = append(extNodes, node) + } + + parsed, err = parser.ParseBytes([]byte(parsed.String()), parser.ParseComments) + if err != nil { + return errors.Wrapf(err, "error parsing yaml: \n%s", parsed.String()) + } + + type internalSpec Spec + var s2 internalSpec + + dec := yaml.NewDecoder(parsed, yaml.Strict()) + if err := dec.Decode(&s2); err != nil { + return fmt.Errorf("%w:\n\n%s", errors.Wrap(err, "error unmarshalling parsed document"), parsed.String()) + } + + *s = Spec(s2) + + node := ast.Mapping(token.MappingStart("", &token.Position{}), false, extNodes...) + doc := ast.Document(&token.Token{Position: &token.Position{}}, node) + s.extRaw = &ast.File{Docs: []*ast.DocumentNode{doc}} + + return nil +} + +func (s Spec) MarshalYAML() ([]byte, error) { + // We need to define a new type to avoid infinite recursion of MarshalYAML. + type internalSpec Spec + + if s.extRaw == nil { + return yaml.Marshal(internalSpec(s)) + } + + var ext map[string]interface{} + if err := yaml.NewDecoder(s.extRaw).Decode(&ext); err != nil { + return nil, errors.Wrap(err, "error unmarshalling extension nodes") + } + + type specExt struct { + internalSpec `yaml:",inline"` + Ext map[string]interface{} `yaml:",omitempty,inline"` } - return yaml.Marshal(obj) + return yaml.Marshal(specExt{ + internalSpec: internalSpec(s), + Ext: ext, + }) } func (s *BuildStep) processBuildArgs(lex *shell.Lex, args map[string]string, allowArg func(string) bool) error { diff --git a/load_test.go b/load_test.go index c6ea41c07..eab3b1081 100644 --- a/load_test.go +++ b/load_test.go @@ -11,6 +11,7 @@ import ( "slices" "testing" + "github.com/goccy/go-yaml" "github.com/moby/buildkit/frontend/dockerui" "gotest.tools/v3/assert" "gotest.tools/v3/assert/cmp" @@ -508,7 +509,7 @@ X-capitalized-other-field: "some other value capitalized X key" src, ok := spec.Sources["test"] if !ok { - t.Fatal("expected source to be present") + t.Fatalf("expected 
source to be present: %+v", spec) } if src.Inline == nil { @@ -1404,5 +1405,53 @@ func testSymlinkFillDefaults(t *testing.T) { } }) } +} + +func checkExt[T any](t *testing.T, spec Spec, key string, expect T) { + t.Helper() + + var actual T + err := spec.Ext(key, &actual) + assert.NilError(t, err) + assert.Check(t, cmp.DeepEqual(actual, expect)) +} + +func TestExtensionFieldMarshalUnmarshal(t *testing.T) { + dt := []byte(` +name: test +x-hello: world +x-foo: +- bar +- baz +X-capitalized: world2 +`) + + var spec Spec + err := yaml.Unmarshal(dt, &spec) + assert.NilError(t, err) + + assert.Check(t, cmp.Equal(spec.Name, "test"), spec) + checkExt(t, spec, "hello", "world") + checkExt(t, spec, "x-hello", "world") + checkExt(t, spec, "foo", []string{"bar", "baz"}) + checkExt(t, spec, "x-foo", []string{"bar", "baz"}) + checkExt(t, spec, "capitalized", "world2") + checkExt(t, spec, "X-capitalized", "world2") + + // marshal and unmarshal to ensure the extension fields are preserved + + dt, err = yaml.Marshal(spec) + assert.NilError(t, err) + + var spec2 Spec + err = yaml.Unmarshal(dt, &spec2) + assert.NilError(t, err) + assert.Check(t, cmp.Equal(spec2.Name, "test"), spec2) + checkExt(t, spec2, "hello", "world") + checkExt(t, spec2, "x-hello", "world") + checkExt(t, spec2, "foo", []string{"bar", "baz"}) + checkExt(t, spec2, "x-foo", []string{"bar", "baz"}) + checkExt(t, spec2, "capitalized", "world2") + checkExt(t, spec2, "X-capitalized", "world2") } diff --git a/spec.go b/spec.go index ec35c6e0e..767595068 100644 --- a/spec.go +++ b/spec.go @@ -3,10 +3,14 @@ package dalec import ( "io/fs" + "strings" "time" + "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/ast" "github.com/moby/buildkit/client/llb" "github.com/opencontainers/go-digest" + "github.com/pkg/errors" ) // Spec is the specification for a package build. @@ -94,6 +98,10 @@ type Spec struct { // Each item in this list is run with a separate rootfs and cannot interact with other tests. 
// Each [TestSpec] is run with a separate rootfs, asynchronously from other [TestSpec]. Tests []*TestSpec `yaml:"tests,omitempty" json:"tests,omitempty"` + + // extRaw is the raw AST of the extension fields in the spec. + // This is used to extract the ext fields in [Spec.Ext] + extRaw *ast.File } // PatchSpec is used to apply a patch to a source with a given set of options. @@ -431,3 +439,55 @@ func (s *SystemdConfiguration) EnabledUnits() map[string]SystemdUnitConfig { return units } + +type ExtDecodeConfig struct { + AllowUnknownFields bool +} + +// Ext reads the extension field from the spec and unmarshals it into the target +// value. +func (s Spec) Ext(key string, target interface{}, opts ...func(*ExtDecodeConfig)) error { + lookup := key + addPrefix := !strings.HasPrefix(key, "x-") && !strings.HasPrefix(key, "X-") + if addPrefix { + lookup = "x-" + key + } + + p, err := yaml.PathString("$." + lookup) + if err != nil { + return err + } + + node, err := p.FilterFile(s.extRaw) + if err != nil { + if addPrefix { + lookup = "X-" + key + p, err = yaml.PathString("$." + lookup) + if err != nil { + return err + } + node, err = p.FilterFile(s.extRaw) + } + if err != nil { + return errors.Wrap(err, "error filtering node") + } + } + + var cfg ExtDecodeConfig + for _, opt := range opts { + opt(&cfg) + } + + var decodeOpts []yaml.DecodeOption + if !cfg.AllowUnknownFields { + decodeOpts = append(decodeOpts, yaml.Strict()) + } + + dt := node.String() + err = yaml.UnmarshalWithOptions([]byte(dt), target, decodeOpts...) + if err != nil { + return errors.Wrapf(err, "error unmarshalling extension field %q into target", key) + } + + return nil +}