-
Notifications
You must be signed in to change notification settings - Fork 135
Expand file tree
/
Copy pathselect_fuzz_test.go
More file actions
498 lines (442 loc) · 15 KB
/
select_fuzz_test.go
File metadata and controls
498 lines (442 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
package sqlbuilder
import (
"fmt"
"math/rand"
"reflect"
"slices"
"sync"
"testing"
)
// fuzzState carries the mutable state of one fuzzed method chain: the raw
// fuzz input and a cursor into it, the human-readable call chain logged for
// reproduction, the builder the next method is invoked on, and which method
// names have already been called successfully.
type fuzzState struct {
	data                    []byte          // raw fuzz input shared by all argument generators
	dataIndex               int             // offset of the next unconsumed byte in data
	callchainRepresentation string          // Go-like rendering of the calls made so far, for logs
	currentBuilder          reflect.Value   // builder the next method call is made on
	usedMethods             map[string]bool // method names already invoked in this chain
}

// consumeData returns the next size bytes of fuzz input and advances the
// cursor past them. If fewer than size bytes remain (or size is negative), it
// returns an empty slice and leaves the cursor unchanged, so callers fall back
// to their zero/default argument values.
//
// The returned slice is a copy; callers may retain or modify it freely.
func (fs *fuzzState) consumeData(size int) []byte {
	// Having exactly size bytes left is enough: only reject when the request
	// would run past the end of the input.
	if size < 0 || fs.dataIndex+size > len(fs.data) {
		return []byte{}
	}
	result := make([]byte, size)
	copy(result, fs.data[fs.dataIndex:fs.dataIndex+size])
	fs.dataIndex += size
	return result
}

// updateCallchain appends ".method(arg1, arg2, ...)" to the recorded call
// chain. Arguments are rendered by formatCallArg so the logged chain reads
// like Go source and does not depend on heap addresses.
func (fs *fuzzState) updateCallchain(method string, args []reflect.Value) {
	call := "." + method + "("
	for i, arg := range args {
		if i > 0 {
			call += ", "
		}
		call += formatCallArg(arg)
	}
	fs.callchainRepresentation += call + ")"
}

// formatCallArg renders one reflected call argument for the logged call chain.
//
// "%q" is only meaningful for strings: applied to other kinds it prints ints
// as rune literals ('\x05'), bools as %!q(bool=true), and *string values as
// %!q(*string=0xc...) with a run-dependent address. So strings (including
// named string types) are quoted, slices and interfaces are rendered through
// their elements, non-string pointers are shown by type name, and every other
// kind uses "%v".
func formatCallArg(arg reflect.Value) string {
	if !arg.IsValid() {
		return "<invalid>"
	}
	switch arg.Kind() {
	case reflect.String:
		return fmt.Sprintf("%q", arg.String())
	case reflect.Slice, reflect.Array:
		out := "["
		for i := 0; i < arg.Len(); i++ {
			if i > 0 {
				out += ", "
			}
			out += formatCallArg(arg.Index(i))
		}
		return out + "]"
	case reflect.Interface:
		if arg.IsNil() {
			return "nil"
		}
		return formatCallArg(arg.Elem())
	case reflect.Ptr:
		if arg.IsNil() {
			return "nil"
		}
		if arg.Elem().Kind() == reflect.String {
			return "&" + formatCallArg(arg.Elem())
		}
		// Builders and other structs: the type is informative, an address is not.
		return "<" + arg.Type().String() + ">"
	default:
		return fmt.Sprintf("%v", arg)
	}
}
// getSelectBuilderMethods enumerates the exported methods of *SelectBuilder
// that the fuzzer may call. It returns a lookup from method name to method
// type (parameter 0 of that type is the receiver) together with the method
// names in the order reflect enumerates them. Methods that finish or inspect
// a builder instead of extending it are excluded.
func getSelectBuilderMethods() (map[string]reflect.Type, []string) {
	builderType := reflect.TypeOf(NewSelectBuilder())
	total := builderType.NumMethod()

	// Excluded: methods likely to cause issues or that don't return builders.
	excluded := []string{
		"Build", "String", "BuildWithFlavor", "Flavor",
		"NumCol", "NumValue", "NumAssignment", "TableNames", "Var",
	}

	types := make(map[string]reflect.Type)
	names := make([]string, 0, total)
	for idx := 0; idx < total; idx++ {
		m := builderType.Method(idx)
		if !slices.Contains(excluded, m.Name) {
			types[m.Name] = m.Type
			names = append(names, m.Name)
		}
	}
	return types, names
}
// generateMethodArgs builds reflected arguments for a call to a method whose
// reflect type is methodType (a method-expression type: parameter 0 is the
// receiver and is not generated). Argument values are derived from the fuzz
// input held in state.
//
// The boolean result is false when some argument could not be produced with a
// type the method accepts. Callers must then skip the call, because
// reflect.Value.Call panics on a mismatched argument.
func generateMethodArgs(methodType reflect.Type, state *fuzzState) ([]reflect.Value, bool) {
	numArgs := methodType.NumIn() - 1 // exclude the receiver
	if methodType.IsVariadic() {
		return generateVariadicArgs(methodType, numArgs, state)
	}
	return generateFixedArgs(methodType, numArgs, state)
}

// isUsableArg reports whether v can be passed for a parameter of type want.
// Every generated value is checked before a call is attempted, for both fixed
// and variadic parameters.
func isUsableArg(v reflect.Value, want reflect.Type) bool {
	return v.IsValid() && v.Type().AssignableTo(want)
}

// generateFixedArgs generates the first numArgs (non-receiver) parameters of
// methodType, requesting 16 bytes of fuzz input per parameter via consumeData.
func generateFixedArgs(methodType reflect.Type, numArgs int, state *fuzzState) ([]reflect.Value, bool) {
	args := make([]reflect.Value, numArgs)
	for i := 0; i < numArgs; i++ {
		argType := methodType.In(i + 1) // skip the receiver
		args[i] = generateArgumentForType(argType, state.consumeData(16))
		if !isUsableArg(args[i], argType) {
			return nil, false
		}
	}
	return args, true
}

// generateVariadicArgs generates the fixed leading parameters of a variadic
// method, followed by 0-3 values for its trailing ...T parameter. The number
// of variadic values comes from one byte of fuzz input (0 when the input is
// exhausted), and each value requests 16 further bytes like a fixed argument.
func generateVariadicArgs(methodType reflect.Type, numArgs int, state *fuzzState) ([]reflect.Value, bool) {
	numFixedArgs := numArgs - 1 // the last parameter is the variadic slice
	args, ok := generateFixedArgs(methodType, numFixedArgs, state)
	if !ok {
		return nil, false
	}

	// 0-3 variadic args, keep the number small while still exercising multiple values
	// TODO: Possible fuzz improvement to allow for more variadic args. Not sure it's worth it.
	numVariadicArgs := 0
	if state.dataIndex < len(state.data) {
		numVariadicArgs = int(state.data[state.dataIndex] % 4)
		state.dataIndex++
	}

	variadicType := methodType.In(numArgs).Elem() // element type T of the ...T parameter
	for j := 0; j < numVariadicArgs; j++ {
		varArg := generateArgumentForType(variadicType, state.consumeData(16))
		if !isUsableArg(varArg, variadicType) {
			return nil, false
		}
		args = append(args, varArg)
	}
	return args, true
}
// tryCallMethod attempts one fuzzed call of methodName on state.currentBuilder.
//
// It returns false when the builder has no such method or when acceptable
// arguments could not be generated (argument generation may already have
// consumed fuzz input). Otherwise it does the following, then returns true:
//   - records the call in the chain representation;
//   - logs the chain before invoking the method, so the chain is visible even
//     if the call panics;
//   - marks the method as used and invokes it.
//
// If the method returns a non-nil value of exactly the current builder's type
// (fluent chaining), that value becomes the builder for subsequent calls.
func tryCallMethod(methodName string, methodType reflect.Type, state *fuzzState, t *testing.T) bool {
	// Check if method exists on current builder
	callableMethod := state.currentBuilder.MethodByName(methodName)
	if !callableMethod.IsValid() {
		return false
	}

	args, canCall := generateMethodArgs(methodType, state)
	if !canCall {
		return false
	}

	state.updateCallchain(methodName, args)
	t.Log("callchain:", state.callchainRepresentation)
	state.usedMethods[methodName] = true

	result := callableMethod.Call(args)

	// Chain only onto results of the current builder's own type. Comparing
	// reflect.Type identity instead of the "*sqlbuilder.SelectBuilder" name
	// string survives package renames and lets this helper drive chains on
	// other builder types as well.
	if len(result) > 0 {
		next := result[0]
		if next.IsValid() && next.Type() == state.currentBuilder.Type() &&
			!(next.Kind() == reflect.Ptr && next.IsNil()) {
			state.currentBuilder = next
		}
	}
	return true
}
// executeMethodChain performs up to maxChains fuzzed method calls on
// state.currentBuilder.
//
// Each step walks methodNames in order and invokes the first method that
// tryCallMethod manages to call. While some name in methodNames is still
// unused, already-used methods are skipped so that chains vary; once all
// have been used, repeats are allowed. The chain stops early when a step
// finds nothing callable.
func executeMethodChain(methodList map[string]reflect.Type, methodNames []string, state *fuzzState, maxChains uint8, t *testing.T) {
nextStep:
	for remaining := maxChains; remaining > 0; remaining-- {
		// tryCallMethod adds to usedMethods only when it returns true, and we
		// jump to the next step right after that, so this flag cannot change
		// during the scan below.
		allowRepeats := len(state.usedMethods) >= len(methodNames)
		for _, name := range methodNames {
			if state.usedMethods[name] && !allowRepeats {
				continue
			}
			if tryCallMethod(name, methodList[name], state, t) {
				continue nextStep
			}
		}
		// Nothing could be called, so the chain cannot be extended.
		return
	}
}
// finalizeBuild invokes Build on the current builder, when that builder is
// valid and has a Build method, and discards the result. The point is to
// surface any panic in SQL generation for the state the fuzzed chain produced.
func finalizeBuild(state *fuzzState) {
	builder := state.currentBuilder
	if !builder.IsValid() {
		return
	}
	build := builder.MethodByName("Build")
	if !build.IsValid() {
		return
	}
	build.Call(nil)
}
// FuzzSelect runs a fuzzed method chain on a fresh SelectBuilder and then
// builds the result, so a panic anywhere in the builder surfaces as a fuzz
// failure. The parameters control the run as follows:
//   - data supplies the argument bytes;
//   - seed fixes the order in which methods are tried;
//   - numberOfChainedFunction, capped at 10, bounds the chain length.
func FuzzSelect(f *testing.F) {
	f.Fuzz(func(t *testing.T, data []byte, seed int64, numberOfChainedFunction uint8) {
		if len(data) == 0 {
			return
		}

		methodList, methodNames := getSelectBuilderMethods()
		// Seed-dependent but deterministic method order.
		rand.New(rand.NewSource(seed)).Shuffle(len(methodNames), reflect.Swapper(methodNames))

		state := &fuzzState{
			data:                    data,
			callchainRepresentation: "NewSelectBuilder()",
			currentBuilder:          reflect.ValueOf(NewSelectBuilder()),
			usedMethods:             map[string]bool{},
		}

		// Cap the chain length so a single iteration stays short.
		maxChains := numberOfChainedFunction
		if maxChains > 10 {
			maxChains = 10
		}

		executeMethodChain(methodList, methodNames, state, maxChains, t)
		t.Logf("Final callchain: %s", state.callchainRepresentation)

		finalizeBuild(state)
	})
}
// generateArgumentForType generates a reflect.Value of type argType from the
// fuzz bytes in data.
//
// Values are derived as follows:
//   - scalars decode the leading bytes of data big-endian, falling back to
//     zero when data is too short;
//   - strings and string slices use data verbatim;
//   - the sqlbuilder pointer and interface types the builders accept get
//     freshly constructed values.
//
// Custom types such as JoinOption and Flavor are handled explicitly because
// Go considers them distinct types from their underlying types.
//
// The result is always either assignable to argType or invalid
// (IsValid() == false). Invalid means "cannot generate this type". Callers
// then skip the method instead of letting reflect.Value.Call panic on a
// type mismatch.
func generateArgumentForType(argType reflect.Type, data []byte) reflect.Value {
	return fitArgumentType(rawArgumentForType(argType, data), argType)
}

// fitArgumentType makes v fit a parameter of type want. It returns:
//   - v itself if it is assignable to want;
//   - v converted to want when both share a kind and Go permits the
//     conversion (e.g. string to a named string type);
//   - an invalid Value otherwise.
func fitArgumentType(v reflect.Value, want reflect.Type) reflect.Value {
	switch {
	case !v.IsValid():
		return reflect.Value{}
	case v.Type().AssignableTo(want):
		return v
	case v.Kind() == want.Kind() && v.Type().ConvertibleTo(want):
		return v.Convert(want)
	default:
		return reflect.Value{}
	}
}

// rawArgumentForType decodes a value for argType's kind from data. The value
// may still have the wrong concrete type. Examples are a plain string for a
// custom string type, *string for an unknown pointer type, or a string for a
// non-empty interface. generateArgumentForType converts or rejects such
// values.
func rawArgumentForType(argType reflect.Type, data []byte) reflect.Value {
	switch argType.Kind() {
	case reflect.String:
		// Handle specific custom string types first
		if argType.String() == "sqlbuilder.JoinOption" {
			joinOptions := []JoinOption{
				FullJoin, FullOuterJoin, InnerJoin,
				LeftJoin, LeftOuterJoin, RightJoin, RightOuterJoin,
			}
			if len(data) > 0 {
				return reflect.ValueOf(joinOptions[int(data[0])%len(joinOptions)])
			}
			return reflect.ValueOf(InnerJoin)
		}
		// Use the data verbatim; other named string types are converted later.
		return reflect.ValueOf(string(data))
	case reflect.Int:
		// Handle specific custom int types first
		if argType.String() == "sqlbuilder.Flavor" {
			return reflect.ValueOf(DefaultFlavor)
		}
		if len(data) > 0 {
			return reflect.ValueOf(int(data[0]))
		}
		return reflect.ValueOf(0)
	case reflect.Bool:
		if len(data) > 0 {
			return reflect.ValueOf(data[0]%2 == 0)
		}
		return reflect.ValueOf(false)
	case reflect.Int8:
		if len(data) > 0 {
			return reflect.ValueOf(int8(data[0]))
		}
		return reflect.ValueOf(int8(0))
	case reflect.Int16:
		if len(data) >= 2 {
			return reflect.ValueOf(int16(data[0])<<8 | int16(data[1]))
		}
		return reflect.ValueOf(int16(0))
	case reflect.Int32:
		if len(data) >= 4 {
			return reflect.ValueOf(int32(data[0])<<24 | int32(data[1])<<16 | int32(data[2])<<8 | int32(data[3]))
		}
		return reflect.ValueOf(int32(0))
	case reflect.Int64:
		if len(data) >= 8 {
			val := int64(data[0])<<56 | int64(data[1])<<48 | int64(data[2])<<40 | int64(data[3])<<32 |
				int64(data[4])<<24 | int64(data[5])<<16 | int64(data[6])<<8 | int64(data[7])
			return reflect.ValueOf(val)
		}
		return reflect.ValueOf(int64(0))
	case reflect.Uint:
		if len(data) > 0 {
			return reflect.ValueOf(uint(data[0]))
		}
		return reflect.ValueOf(uint(0))
	case reflect.Uint8:
		if len(data) > 0 {
			return reflect.ValueOf(uint8(data[0]))
		}
		return reflect.ValueOf(uint8(0))
	case reflect.Uint16:
		if len(data) >= 2 {
			return reflect.ValueOf(uint16(data[0])<<8 | uint16(data[1]))
		}
		return reflect.ValueOf(uint16(0))
	case reflect.Uint32:
		if len(data) >= 4 {
			return reflect.ValueOf(uint32(data[0])<<24 | uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]))
		}
		return reflect.ValueOf(uint32(0))
	case reflect.Uint64:
		if len(data) >= 8 {
			val := uint64(data[0])<<56 | uint64(data[1])<<48 | uint64(data[2])<<40 | uint64(data[3])<<32 |
				uint64(data[4])<<24 | uint64(data[5])<<16 | uint64(data[6])<<8 | uint64(data[7])
			return reflect.ValueOf(val)
		}
		return reflect.ValueOf(uint64(0))
	case reflect.Float32:
		// A numeric conversion of the decoded integer, not a bit reinterpretation
		// (math.Float32frombits), so NaN and ±Inf are never generated.
		if len(data) >= 4 {
			raw := uint32(data[0])<<24 | uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
			return reflect.ValueOf(float32(raw))
		}
		return reflect.ValueOf(float32(0))
	case reflect.Float64:
		// Numeric conversion, as for Float32: the result is always finite.
		if len(data) >= 8 {
			raw := uint64(data[0])<<56 | uint64(data[1])<<48 | uint64(data[2])<<40 | uint64(data[3])<<32 |
				uint64(data[4])<<24 | uint64(data[5])<<16 | uint64(data[6])<<8 | uint64(data[7])
			return reflect.ValueOf(float64(raw))
		}
		return reflect.ValueOf(float64(0))
	case reflect.Slice:
		if argType.Elem().Kind() == reflect.String {
			return reflect.ValueOf([]string{string(data)})
		}
		// []interface{} fits interface-element slices; any other element type
		// is rejected by fitArgumentType rather than passed through mismatched.
		return reflect.ValueOf([]interface{}{string(data)})
	case reflect.Ptr:
		// Handle specific pointer types
		if argType == reflect.TypeOf((*WhereClause)(nil)) {
			return reflect.ValueOf(NewWhereClause())
		}
		if argType == reflect.TypeOf((*SelectBuilder)(nil)) {
			return reflect.ValueOf(NewSelectBuilder())
		}
		if argType == reflect.TypeOf((*Args)(nil)) {
			return reflect.ValueOf(&Args{})
		}
		if argType == reflect.TypeOf((*CTEBuilder)(nil)) {
			return reflect.ValueOf(DefaultFlavor.NewCTEBuilder())
		}
		if argType == reflect.TypeOf((*InsertBuilder)(nil)) {
			return reflect.ValueOf(DefaultFlavor.NewInsertBuilder())
		}
		if argType == reflect.TypeOf((*UpdateBuilder)(nil)) {
			return reflect.ValueOf(DefaultFlavor.NewUpdateBuilder())
		}
		if argType == reflect.TypeOf((*DeleteBuilder)(nil)) {
			return reflect.ValueOf(DefaultFlavor.NewDeleteBuilder())
		}
		// Fallback: a *string holding data. This fits *string and pointers to
		// named string types; other pointer types are rejected by fitArgumentType.
		str := string(data)
		return reflect.ValueOf(&str)
	case reflect.Interface:
		// Handle specific interface types
		if argType.String() == "sqlbuilder.Builder" {
			// Create a simple SelectBuilder for Builder interface
			return reflect.ValueOf(NewSelectBuilder())
		}
		// A string satisfies empty interfaces only; others are rejected later.
		return reflect.ValueOf(string(data))
	default:
		// For other types, use zero value
		return reflect.Zero(argType)
	}
}
// FuzzSelectClone fuzzes SelectBuilder.Clone behavior under concurrent usage
// and ensures cloned instances are independent and safe to mutate.
//
// The test proceeds in three steps:
//  1. A fuzzed method chain shapes a base builder.
//  2. 1-4 goroutines clone the base simultaneously and run their own fuzzed
//     chains on the clones. Afterwards the base must still build to the same
//     SQL and args.
//  3. It checks that mutating one clone affects neither another clone nor
//     the base.
func FuzzSelectClone(f *testing.F) {
	f.Fuzz(func(t *testing.T, data []byte, seed int64, numberOfChainedFunction uint8) {
		if len(data) == 0 {
			return
		}
		methodList, methodNames := getSelectBuilderMethods()
		r := rand.New(rand.NewSource(seed))
		r.Shuffle(len(methodNames), func(i, j int) {
			methodNames[i], methodNames[j] = methodNames[j], methodNames[i]
		})

		// Build a base template SelectBuilder via fuzzed method chains.
		base := NewSelectBuilder()
		baseState := &fuzzState{
			data:                    data,
			dataIndex:               0,
			callchainRepresentation: "NewSelectBuilder()",
			currentBuilder:          reflect.ValueOf(base),
			usedMethods:             make(map[string]bool),
		}
		maxChains := numberOfChainedFunction
		if maxChains > 10 {
			maxChains = 10
		}
		executeMethodChain(methodList, methodNames, baseState, maxChains, t)
		baseSQLBefore, baseArgsBefore := base.Build()

		// Clone concurrently and mutate clones with fuzzed chains. The goroutines
		// only read data, methodList and methodNames; each gets its own fuzzState.
		cloneCount := int(r.Uint32()%4) + 1 // 1..4 clones
		var wg sync.WaitGroup
		wg.Add(cloneCount)
		start := make(chan struct{})
		type result struct {
			sql  string
			args []interface{}
		}
		results := make(chan result, cloneCount)
		for i := 0; i < cloneCount; i++ {
			// Use different offsets into the same fuzz data for variety
			// (data is non-empty here, so the modulo is safe).
			offset := (i * 17) % len(data)
			go func(off int) {
				defer wg.Done()
				<-start // start all goroutines roughly at the same time
				c := base.Clone()
				st := &fuzzState{
					data:                    data,
					dataIndex:               off,
					callchainRepresentation: "Clone()",
					currentBuilder:          reflect.ValueOf(c),
					usedMethods:             make(map[string]bool),
				}
				executeMethodChain(methodList, methodNames, st, maxChains, t)
				finalizeBuild(st) // ensure no panic on Build
				s, a := c.Build()
				results <- result{sql: s, args: a}
			}(offset)
		}
		close(start)
		wg.Wait()
		close(results)

		// Ensure base builder stays unchanged after concurrent cloning/mutation of clones.
		baseSQLAfter, baseArgsAfter := base.Build()
		if baseSQLBefore != baseSQLAfter || !reflect.DeepEqual(baseArgsBefore, baseArgsAfter) {
			t.Fatalf("base builder mutated by clones:\n before: %s %v\n after: %s %v", baseSQLBefore, baseArgsBefore, baseSQLAfter, baseArgsAfter)
		}

		// Independence check: mutating one clone does not affect another clone.
		cloneA := base.Clone()
		sA1, aA1 := cloneA.Build()
		done := make(chan struct{})
		go func() {
			defer close(done)
			c2 := base.Clone()
			// Apply a deterministic small change; should not affect cloneA.
			c2.OrderBy("id").Desc().Limit(1).Offset(0)
			_, _ = c2.Build()
		}()
		// Build cloneA while c2 may still be mutated (useful under -race)...
		sA2, aA2 := cloneA.Build()
		// ...and again once c2's mutation has certainly completed, so the check
		// cannot pass merely because it ran first. Joining here also keeps the
		// goroutine from outliving a failed iteration.
		<-done
		sA3, aA3 := cloneA.Build()
		if sA1 != sA2 || !reflect.DeepEqual(aA1, aA2) ||
			sA1 != sA3 || !reflect.DeepEqual(aA1, aA3) {
			t.Fatalf("cloneA changed after mutating another clone")
		}

		// Further independence: modifying cloneA should not affect the base.
		cloneA.Limit(3).Asc()
		_ = cloneA.String()
		baseSQLFinal, baseArgsFinal := base.Build()
		if baseSQLFinal != baseSQLAfter || !reflect.DeepEqual(baseArgsFinal, baseArgsAfter) {
			t.Fatalf("base changed after modifying a clone")
		}

		// Drain results to ensure all builds completed; mainly to use the values and avoid lints.
		for range results {
		}
	})
}