Skip to content

Commit f31fc0b

Browse files
committed
fix queryResult
1 parent 403d433 commit f31fc0b

8 files changed

+97
-50
lines changed

adapters/clickhouse_adapter.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"fmt"
2020
"io"
2121
"log/slog"
22+
"reflect"
2223
"strings"
2324

2425
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
@@ -229,18 +230,28 @@ func (a *ClickhouseDbqDataSourceAdapter) InterpretDataQualityCheck(check *dbqcor
229230
return sqlQuery, nil
230231
}
231232

232-
func (a *ClickhouseDbqDataSourceAdapter) ExecuteQuery(ctx context.Context, query string) (string, error) {
233+
func (a *ClickhouseDbqDataSourceAdapter) ExecuteQuery(ctx context.Context, query string) (interface{}, error) {
233234
rows, err := a.cnn.Query(ctx, query)
234235
if err != nil {
235236
return "", fmt.Errorf("failed to execute query for check: %v", err)
236237
}
237238
defer rows.Close()
238239

239-
var queryResult string
240+
var queryResult interface{}
240241
for rows.Next() {
241-
if err := rows.Scan(&queryResult); err != nil {
242-
return "", fmt.Errorf("failed to scan result for check: %v", err)
242+
scanArgs := make([]interface{}, len(rows.Columns()))
243+
for i, colType := range rows.ColumnTypes() {
244+
scanType := colType.ScanType()
245+
valuePtr := reflect.New(scanType).Interface()
246+
scanArgs[i] = valuePtr
243247
}
248+
249+
err = rows.Scan(scanArgs...)
250+
if err != nil {
251+
return "", err
252+
}
253+
254+
queryResult = reflect.ValueOf(scanArgs[0]).Elem().Interface()
244255
}
245256

246257
if err = rows.Err(); err != nil {

adapters/mysql_adapter.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,14 @@ func (a *MysqlDbqDataSourceAdapter) InterpretDataQualityCheck(check *dbqcore.Dat
229229
return sqlQuery, nil
230230
}
231231

232-
func (a *MysqlDbqDataSourceAdapter) ExecuteQuery(ctx context.Context, query string) (string, error) {
232+
func (a *MysqlDbqDataSourceAdapter) ExecuteQuery(ctx context.Context, query string) (interface{}, error) {
233233
rows, err := a.db.QueryContext(ctx, query)
234234
if err != nil {
235235
return "", fmt.Errorf("failed to execute query for check: %v", err)
236236
}
237237
defer rows.Close()
238238

239-
var queryResult string
239+
var queryResult interface{}
240240
for rows.Next() {
241241
if err := rows.Scan(&queryResult); err != nil {
242242
return "", fmt.Errorf("failed to scan result for check: %v", err)

adapters/postgresql_adapter.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,14 @@ func (a *PostgresqlDbqDataSourceAdapter) InterpretDataQualityCheck(check *dbqcor
229229
return sqlQuery, nil
230230
}
231231

232-
func (a *PostgresqlDbqDataSourceAdapter) ExecuteQuery(ctx context.Context, query string) (string, error) {
232+
func (a *PostgresqlDbqDataSourceAdapter) ExecuteQuery(ctx context.Context, query string) (interface{}, error) {
233233
rows, err := a.db.QueryContext(ctx, query)
234234
if err != nil {
235235
return "", fmt.Errorf("failed to execute query for check: %v", err)
236236
}
237237
defer rows.Close()
238238

239-
var queryResult string
239+
var queryResult interface{}
240240
for rows.Next() {
241241
if err := rows.Scan(&queryResult); err != nil {
242242
return "", fmt.Errorf("failed to scan result for check: %v", err)

dbq/dbq.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
)
2727

2828
const (
29-
Version = "v0.5.0"
29+
Version = "v0.5.1"
3030
)
3131

3232
func GetDbqCoreLibVersion() string {

dbq_validator.go

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type DbqDataSourceAdapter interface {
4949
InterpretDataQualityCheck(check *DataQualityCheck, dataset string, defaultWhere string) (string, error)
5050

5151
// ExecuteQuery executes the SQL query and returns the query result
52-
ExecuteQuery(ctx context.Context, query string) (string, error)
52+
ExecuteQuery(ctx context.Context, query string) (interface{}, error)
5353
}
5454

5555
func NewDbqDataValidator(logger *slog.Logger) DbqDataValidator {
@@ -94,37 +94,37 @@ func (d DbqDataValidatorImpl) RunCheck(ctx context.Context, adapter DbqDataSourc
9494
"check_expression", check.Expression,
9595
"duration_ms", elapsed)
9696

97-
result.QueryResultValue = queryResult
97+
result.QueryResultValue = fmt.Sprintf("%v", queryResult)
9898

9999
// Handle schema checks specially
100100
if check.SchemaCheck != nil {
101101
// For schema checks, we expect the count to match the expected value
102102
if check.SchemaCheck.ExpectColumnsOrdered != nil {
103103
// For expect_columns_ordered, the count should match the number of expected columns
104104
expectedCount := len(check.SchemaCheck.ExpectColumnsOrdered.ColumnsOrder)
105-
actualCount, err := strconv.Atoi(queryResult)
105+
actualCount, err := d.convertToInt(queryResult)
106106
if err != nil || actualCount != expectedCount {
107107
result.Pass = false
108-
result.Error = fmt.Sprintf("Check failed: %s == %d (got: %s)", check.Expression, expectedCount, queryResult)
108+
result.Error = fmt.Sprintf("Check failed: %s == %d (got: %v)", check.Expression, expectedCount, queryResult)
109109
} else {
110110
result.Pass = true
111111
}
112112
} else if check.SchemaCheck.ExpectColumns != nil {
113113
// For expect_columns, the count should match the number of expected columns
114114
expectedCount := len(check.SchemaCheck.ExpectColumns.Columns)
115-
actualCount, err := strconv.Atoi(queryResult)
115+
actualCount, err := d.convertToInt(queryResult)
116116
if err != nil || actualCount != expectedCount {
117117
result.Pass = false
118-
result.Error = fmt.Sprintf("Check failed: %s == %d (got: %s)", check.Expression, expectedCount, queryResult)
118+
result.Error = fmt.Sprintf("Check failed: %s == %d (got: %v)", check.Expression, expectedCount, queryResult)
119119
} else {
120120
result.Pass = true
121121
}
122122
} else if check.SchemaCheck.ColumnsNotPresent != nil {
123123
// For columns_not_present, the count should be 0 (no unwanted columns should exist)
124-
actualCount, err := strconv.Atoi(queryResult)
124+
actualCount, err := d.convertToInt(queryResult)
125125
if err != nil {
126126
result.Pass = false
127-
result.Error = fmt.Sprintf("Check failed: %s invalid result: %s", check.Expression, queryResult)
127+
result.Error = fmt.Sprintf("Check failed: %s invalid result: %v", check.Expression, queryResult)
128128
} else if actualCount > 0 {
129129
result.Pass = false
130130
result.Error = fmt.Sprintf("Check failed: %s found %d unwanted columns", check.Expression, actualCount)
@@ -144,19 +144,19 @@ func (d DbqDataValidatorImpl) RunCheck(ctx context.Context, adapter DbqDataSourc
144144
}
145145

146146
// validateResult checks if the query result meets the check criteria
147-
func (d DbqDataValidatorImpl) validateResult(queryResult string, parsedCheck *CheckExpression) bool {
147+
func (d DbqDataValidatorImpl) validateResult(queryResult interface{}, parsedCheck *CheckExpression) bool {
148148
if parsedCheck == nil {
149149
// If there's no parsed check, consider it a pass (raw queries without validation)
150150
return true
151151
}
152152

153153
// If there's no operator, just check if we got a result (for functions like raw_query)
154154
if parsedCheck.Operator == "" {
155-
return queryResult != ""
155+
return queryResult != nil && fmt.Sprintf("%v", queryResult) != ""
156156
}
157157

158158
// Convert query result to float64 for numeric comparisons
159-
actualValue, err := strconv.ParseFloat(queryResult, 64)
159+
actualValue, err := d.convertToFloat64(queryResult)
160160
if err != nil {
161161
d.logger.Warn("Failed to parse query result as number, treating as string comparison",
162162
"result", queryResult,
@@ -187,22 +187,23 @@ func (d DbqDataValidatorImpl) validateResult(queryResult string, parsedCheck *Ch
187187
}
188188

189189
// validateStringResult handles string-based comparisons when numeric parsing fails
190-
func (d DbqDataValidatorImpl) validateStringResult(queryResult string, parsedCheck *CheckExpression) bool {
190+
func (d DbqDataValidatorImpl) validateStringResult(queryResult interface{}, parsedCheck *CheckExpression) bool {
191+
queryResultStr := fmt.Sprintf("%v", queryResult)
191192
switch parsedCheck.Operator {
192193
case "==", "=":
193194
if thresholdStr, ok := parsedCheck.ThresholdValue.(string); ok {
194-
return queryResult == thresholdStr
195+
return queryResultStr == thresholdStr
195196
}
196-
return queryResult == fmt.Sprintf("%v", parsedCheck.ThresholdValue)
197+
return queryResultStr == fmt.Sprintf("%v", parsedCheck.ThresholdValue)
197198
case "!=", "<>":
198199
if thresholdStr, ok := parsedCheck.ThresholdValue.(string); ok {
199-
return queryResult != thresholdStr
200+
return queryResultStr != thresholdStr
200201
}
201-
return queryResult != fmt.Sprintf("%v", parsedCheck.ThresholdValue)
202+
return queryResultStr != fmt.Sprintf("%v", parsedCheck.ThresholdValue)
202203
default:
203204
d.logger.Warn("String comparison not supported for operator, defaulting to false",
204205
"operator", parsedCheck.Operator,
205-
"result", queryResult)
206+
"result", queryResultStr)
206207
return false
207208
}
208209
}
@@ -340,3 +341,37 @@ func (d DbqDataValidatorImpl) convertToFloat64(value interface{}) (float64, erro
340341
return 0, fmt.Errorf("unsupported type: %T", value)
341342
}
342343
}
344+
345+
// convertToInt converts various types to int
346+
func (d DbqDataValidatorImpl) convertToInt(value interface{}) (int, error) {
347+
switch v := value.(type) {
348+
case int:
349+
return v, nil
350+
case int8:
351+
return int(v), nil
352+
case int16:
353+
return int(v), nil
354+
case int32:
355+
return int(v), nil
356+
case int64:
357+
return int(v), nil
358+
case uint:
359+
return int(v), nil
360+
case uint8:
361+
return int(v), nil
362+
case uint16:
363+
return int(v), nil
364+
case uint32:
365+
return int(v), nil
366+
case uint64:
367+
return int(v), nil
368+
case float32:
369+
return int(v), nil
370+
case float64:
371+
return int(v), nil
372+
case string:
373+
return strconv.Atoi(v)
374+
default:
375+
return 0, fmt.Errorf("unsupported type: %T", value)
376+
}
377+
}

dbq_validator_columns_not_present_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ func (m *MockColumnsNotPresentAdapter) InterpretDataQualityCheck(check *DataQual
1717
return "", fmt.Errorf("not a columns_not_present check")
1818
}
1919

20-
func (m *MockColumnsNotPresentAdapter) ExecuteQuery(ctx context.Context, query string) (string, error) {
20+
func (m *MockColumnsNotPresentAdapter) ExecuteQuery(ctx context.Context, query string) (interface{}, error) {
2121
// Return different counts based on the query to test different scenarios
2222
if query == "SELECT COUNT(*) FROM columns WHERE unwanted = true" {
23-
return "0", nil // No unwanted columns found - check should pass
23+
return 0, nil // No unwanted columns found - check should pass
2424
}
2525
if query == "SELECT COUNT(*) FROM columns WHERE unwanted = true WITH UNWANTED" {
26-
return "3", nil // 3 unwanted columns found - check should fail
26+
return 3, nil // 3 unwanted columns found - check should fail
2727
}
2828
if query == "SELECT COUNT(*) FROM columns WHERE unwanted = true WITH ERROR" {
2929
return "invalid", nil // Invalid result - check should fail
3030
}
31-
return "0", nil
31+
return 0, nil
3232
}
3333

3434
func TestValidateColumnsNotPresent(t *testing.T) {
@@ -107,17 +107,18 @@ func TestValidateColumnsNotPresent(t *testing.T) {
107107

108108
// Validate result
109109
result := ValidationResult{
110-
QueryResultValue: queryResult,
110+
QueryResultValue: fmt.Sprintf("%v", queryResult),
111111
}
112112

113113
// Simulate the validator logic for columns_not_present
114114
if tt.check.SchemaCheck != nil && tt.check.SchemaCheck.ColumnsNotPresent != nil {
115115
count := 0
116-
fmt.Sscanf(queryResult, "%d", &count)
116+
queryResultStr := fmt.Sprintf("%v", queryResult)
117+
fmt.Sscanf(queryResultStr, "%d", &count)
117118

118-
if queryResult == "invalid" {
119+
if queryResultStr == "invalid" {
119120
result.Pass = false
120-
result.Error = fmt.Sprintf("Check failed: %s invalid result: %s", tt.check.Expression, queryResult)
121+
result.Error = fmt.Sprintf("Check failed: %s invalid result: %s", tt.check.Expression, queryResultStr)
121122
} else if count > 0 {
122123
result.Pass = false
123124
result.Error = fmt.Sprintf("Check failed: %s found %d unwanted columns", tt.check.Expression, count)

dbq_validator_expect_columns_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
// MockAdapterForExpectColumns is a test adapter for expect_columns checks
1111
type MockAdapterForExpectColumns struct {
1212
expectedQuery string
13-
returnValue string
13+
returnValue interface{}
1414
returnError error
1515
}
1616

@@ -22,7 +22,7 @@ func (m *MockAdapterForExpectColumns) InterpretDataQualityCheck(check *DataQuali
2222
return "", nil
2323
}
2424

25-
func (m *MockAdapterForExpectColumns) ExecuteQuery(ctx context.Context, query string) (string, error) {
25+
func (m *MockAdapterForExpectColumns) ExecuteQuery(ctx context.Context, query string) (interface{}, error) {
2626
m.expectedQuery = query
2727
return m.returnValue, m.returnError
2828
}
@@ -31,7 +31,7 @@ func TestDbqDataValidator_ExpectColumns(t *testing.T) {
3131
tests := []struct {
3232
name string
3333
check DataQualityCheck
34-
queryResult string
34+
queryResult interface{}
3535
expectedPassed bool
3636
expectedReason string
3737
}{
@@ -45,7 +45,7 @@ func TestDbqDataValidator_ExpectColumns(t *testing.T) {
4545
},
4646
},
4747
},
48-
queryResult: "3",
48+
queryResult: 3,
4949
expectedPassed: true,
5050
expectedReason: "",
5151
},
@@ -60,7 +60,7 @@ func TestDbqDataValidator_ExpectColumns(t *testing.T) {
6060
},
6161
},
6262
},
63-
queryResult: "2",
63+
queryResult: 2,
6464
expectedPassed: false,
6565
expectedReason: "Check failed: expect_columns == 3 (got: 2)",
6666
},
@@ -75,7 +75,7 @@ func TestDbqDataValidator_ExpectColumns(t *testing.T) {
7575
},
7676
},
7777
},
78-
queryResult: "0",
78+
queryResult: 0,
7979
expectedPassed: false,
8080
expectedReason: "Check failed: expect_columns == 2 (got: 0)",
8181
},
@@ -89,7 +89,7 @@ func TestDbqDataValidator_ExpectColumns(t *testing.T) {
8989
},
9090
},
9191
},
92-
queryResult: "1",
92+
queryResult: 1,
9393
expectedPassed: true,
9494
expectedReason: "",
9595
},

0 commit comments

Comments
 (0)