Skip to content

Commit a358f88

Browse files
fix: add structural compatibility for map-to-struct and struct-to-map conversions
Includes test cases and implementation for validating map-to-struct and struct-to-map conversions, ensuring structural compatibility of keys and fields. Signed-off-by: Jakob Möller <[email protected]>
1 parent c831e6b commit a358f88

File tree

2 files changed

+231
-6
lines changed

2 files changed

+231
-6
lines changed

pkg/cel/compatibility.go

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const (
3737
// - For lists: recursively checks element type compatibility
3838
// - For maps: recursively checks key and value type compatibility
3939
// - For structs: uses DeclTypeProvider to introspect fields and check all required fields exist with compatible types
40+
// - For map → struct and struct → map compatibility if fields/keys are structurally compatible
4041
//
4142
// The provider is required for introspecting struct field information.
4243
// Returns true if types are compatible, false otherwise. If false, the error describes why.
@@ -50,23 +51,28 @@ func AreTypesStructurallyCompatible(output, expected *cel.Type, provider *DeclTy
5051
return true, nil
5152
}
5253

54+
// Unwrap optional output if available
5355
if output.Kind() == cel.OpaqueKind && output.TypeName() == "optional_type" {
5456
return AreTypesStructurallyCompatible(output.Parameters()[0], expected, provider)
5557
}
5658

57-
switch expected.Kind() {
58-
case cel.ListKind:
59+
switch {
60+
case expected.Kind() == cel.StructKind && output.Kind() == cel.MapKind:
61+
return areMapTypesAssignableToStruct(output, expected, provider)
62+
case expected.Kind() == cel.MapKind && output.Kind() == cel.StructKind:
63+
return areStructTypesAssignableToMap(output, expected, provider)
64+
case expected.Kind() == cel.ListKind:
5965
return areListTypesCompatible(output, expected, provider)
60-
case cel.MapKind:
66+
case expected.Kind() == cel.MapKind:
6167
return areMapTypesCompatible(output, expected, provider)
62-
case cel.StructKind:
68+
case expected.Kind() == cel.StructKind:
6369
return areStructTypesCompatible(output, expected, provider)
6470
default:
65-
// Kinds must match
71+
// Kinds must match otherwise
6672
if output.Kind() != expected.Kind() {
6773
return false, fmt.Errorf("type kind mismatch: got %q, expected %q", output.String(), expected.String())
6874
}
69-
// For primitives (int, string, bool, etc.), kind equality is enough
75+
// primitives: kind equality already checked
7076
return true, nil
7177
}
7278
}
@@ -275,3 +281,70 @@ func areStructFieldsCompatible(output, expected *apiservercel.DeclType, provider
275281

276282
return true, nil
277283
}
284+
285+
func areMapTypesAssignableToStruct(outputMap, expectedStruct *cel.Type, provider *DeclTypeProvider) (bool, error) {
286+
expectedDecl := resolveDeclTypeFromPath(expectedStruct.String(), provider)
287+
if expectedDecl == nil || expectedDecl.Fields == nil {
288+
return true, nil
289+
}
290+
291+
// map parameters are [keyType, valueType]
292+
params := outputMap.Parameters()
293+
if len(params) < 2 {
294+
return false, fmt.Errorf("map must have key and value types")
295+
}
296+
297+
keyType := params[0]
298+
valType := params[1]
299+
300+
// keys must be strings to match struct field names
301+
if keyType.Kind() != cel.StringKind {
302+
return false, fmt.Errorf("map keys must be strings to assign to struct")
303+
}
304+
305+
for fieldName, expectedField := range expectedDecl.Fields {
306+
expectedFieldCEL := expectedField.Type.CelType()
307+
if expectedFieldCEL == nil {
308+
continue
309+
}
310+
compatible, err := AreTypesStructurallyCompatible(valType, expectedFieldCEL, provider)
311+
if !compatible {
312+
return false, fmt.Errorf("map value incompatible with struct field %q: %w", fieldName, err)
313+
}
314+
}
315+
316+
return true, nil
317+
}
318+
319+
func areStructTypesAssignableToMap(outputStruct, expectedMap *cel.Type, provider *DeclTypeProvider) (bool, error) {
320+
outputDecl := resolveDeclTypeFromPath(outputStruct.String(), provider)
321+
if outputDecl == nil || outputDecl.Fields == nil {
322+
return true, nil
323+
}
324+
325+
params := expectedMap.Parameters()
326+
if len(params) < 2 {
327+
return false, fmt.Errorf("expected map must have key and value types")
328+
}
329+
keyType := params[0]
330+
valType := params[1]
331+
332+
// struct field names map to string keys
333+
if keyType.Kind() != cel.StringKind {
334+
return false, fmt.Errorf("map key type must be string when assigning struct → map")
335+
}
336+
337+
for fieldName, outputField := range outputDecl.Fields {
338+
outputCEL := outputField.Type.CelType()
339+
if outputCEL == nil {
340+
continue
341+
}
342+
343+
compatible, err := AreTypesStructurallyCompatible(outputCEL, valType, provider)
344+
if !compatible {
345+
return false, fmt.Errorf("struct field %q incompatible with map value type: %w", fieldName, err)
346+
}
347+
}
348+
349+
return true, nil
350+
}

pkg/cel/compatibility_test.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,158 @@ func TestNestedTypes(t *testing.T) {
438438
}
439439
}
440440

441+
func TestMapToStructCompatibility(t *testing.T) {
442+
// homogeneous struct: {a:int, b:int}
443+
intStructFields := map[string]*apiservercel.DeclField{
444+
"a": apiservercel.NewDeclField("a", apiservercel.IntType, true, nil, nil),
445+
"b": apiservercel.NewDeclField("b", apiservercel.IntType, false, nil, nil),
446+
}
447+
intStruct := apiservercel.NewObjectType(TypeNamePrefix+"intStruct", intStructFields)
448+
449+
// homogeneous struct: {x:string, y:string}
450+
stringStructFields := map[string]*apiservercel.DeclField{
451+
"x": apiservercel.NewDeclField("x", apiservercel.StringType, true, nil, nil),
452+
"y": apiservercel.NewDeclField("y", apiservercel.StringType, true, nil, nil),
453+
}
454+
stringStruct := apiservercel.NewObjectType(TypeNamePrefix+"stringStruct", stringStructFields)
455+
456+
provider := NewDeclTypeProvider(intStruct, stringStruct)
457+
458+
tests := []struct {
459+
name string
460+
output *cel.Type
461+
expected *cel.Type
462+
compatible bool
463+
errContains string
464+
}{
465+
{
466+
name: "map[string]int → struct{a:int, b:int} OK",
467+
output: cel.MapType(cel.StringType, cel.IntType),
468+
expected: intStruct.CelType(),
469+
compatible: true,
470+
},
471+
{
472+
name: "map[string]optional<int> → struct{a:int, b:int} OK",
473+
output: cel.MapType(cel.StringType, cel.OptionalType(cel.IntType)),
474+
expected: intStruct.CelType(),
475+
compatible: true,
476+
},
477+
{
478+
name: "map[string]string → struct{x:string, y:string} OK",
479+
output: cel.MapType(cel.StringType, cel.StringType),
480+
expected: stringStruct.CelType(),
481+
compatible: true,
482+
},
483+
{
484+
name: "map[string]dyn → struct{x:string, y:string} OK",
485+
output: cel.MapType(cel.StringType, cel.DynType),
486+
expected: stringStruct.CelType(),
487+
compatible: true,
488+
},
489+
{
490+
name: "map[int]string → struct{x:string,y:string} invalid (key type)",
491+
output: cel.MapType(cel.IntType, cel.StringType),
492+
expected: stringStruct.CelType(),
493+
compatible: false,
494+
errContains: "keys must be strings",
495+
},
496+
{
497+
name: "map[string]int → struct{x:string,y:string} incompatible",
498+
output: cel.MapType(cel.StringType, cel.IntType),
499+
expected: stringStruct.CelType(),
500+
compatible: false,
501+
errContains: "incompatible",
502+
},
503+
}
504+
505+
for _, tt := range tests {
506+
t.Run(tt.name, func(t *testing.T) {
507+
compatible, err := AreTypesStructurallyCompatible(tt.output, tt.expected, provider)
508+
assert.Equal(t, tt.compatible, compatible)
509+
if tt.compatible {
510+
assert.NoError(t, err)
511+
} else {
512+
require.Error(t, err)
513+
if tt.errContains != "" {
514+
assert.Contains(t, err.Error(), tt.errContains)
515+
}
516+
}
517+
})
518+
}
519+
}
520+
521+
func TestStructToMapCompatibility(t *testing.T) {
522+
intStructFields := map[string]*apiservercel.DeclField{
523+
"a": apiservercel.NewDeclField("a", apiservercel.IntType, true, nil, nil),
524+
"b": apiservercel.NewDeclField("b", apiservercel.IntType, false, nil, nil),
525+
}
526+
intStruct := apiservercel.NewObjectType(TypeNamePrefix+"intStruct", intStructFields)
527+
528+
stringStructFields := map[string]*apiservercel.DeclField{
529+
"x": apiservercel.NewDeclField("x", apiservercel.StringType, true, nil, nil),
530+
"y": apiservercel.NewDeclField("y", apiservercel.StringType, true, nil, nil),
531+
}
532+
stringStruct := apiservercel.NewObjectType(TypeNamePrefix+"stringStruct", stringStructFields)
533+
534+
provider := NewDeclTypeProvider(intStruct, stringStruct)
535+
536+
tests := []struct {
537+
name string
538+
output *cel.Type
539+
expected *cel.Type
540+
compatible bool
541+
errContains string
542+
}{
543+
{
544+
name: "struct{a:int,b:int} → map[string]int OK",
545+
output: intStruct.CelType(),
546+
expected: cel.MapType(cel.StringType, cel.IntType),
547+
compatible: true,
548+
},
549+
{
550+
name: "struct{x:string,y:string} → map[string]string OK",
551+
output: stringStruct.CelType(),
552+
expected: cel.MapType(cel.StringType, cel.StringType),
553+
compatible: true,
554+
},
555+
{
556+
name: "struct{x:string,y:string} → map[string]dyn OK",
557+
output: stringStruct.CelType(),
558+
expected: cel.MapType(cel.StringType, cel.DynType),
559+
compatible: true,
560+
},
561+
{
562+
name: "struct{x:string,y:string} → map[int]string invalid (bad key)",
563+
output: stringStruct.CelType(),
564+
expected: cel.MapType(cel.IntType, cel.StringType),
565+
compatible: false,
566+
errContains: "key type must be string",
567+
},
568+
{
569+
name: "struct{x:string,y:string} → map[string]int incompatible",
570+
output: stringStruct.CelType(),
571+
expected: cel.MapType(cel.StringType, cel.IntType),
572+
compatible: false,
573+
errContains: "incompatible",
574+
},
575+
}
576+
577+
for _, tt := range tests {
578+
t.Run(tt.name, func(t *testing.T) {
579+
compatible, err := AreTypesStructurallyCompatible(tt.output, tt.expected, provider)
580+
assert.Equal(t, tt.compatible, compatible)
581+
if tt.compatible {
582+
assert.NoError(t, err)
583+
} else {
584+
require.Error(t, err)
585+
if tt.errContains != "" {
586+
assert.Contains(t, err.Error(), tt.errContains)
587+
}
588+
}
589+
})
590+
}
591+
}
592+
441593
func TestOptionalPrimitive(t *testing.T) {
442594
optionalString := cel.OpaqueType("optional_type", cel.StringType)
443595
optionalInt := cel.OpaqueType("optional_type", cel.IntType)

0 commit comments

Comments
 (0)