@@ -49,7 +49,7 @@ type DbqDataSourceAdapter interface {
4949 InterpretDataQualityCheck (check * DataQualityCheck , dataset string , defaultWhere string ) (string , error )
5050
5151 // ExecuteQuery executes the SQL query and returns the query result
52- ExecuteQuery (ctx context.Context , query string ) (string , error )
52+ ExecuteQuery (ctx context.Context , query string ) (interface {} , error )
5353}
5454
5555func NewDbqDataValidator (logger * slog.Logger ) DbqDataValidator {
@@ -94,37 +94,37 @@ func (d DbqDataValidatorImpl) RunCheck(ctx context.Context, adapter DbqDataSourc
9494 "check_expression" , check .Expression ,
9595 "duration_ms" , elapsed )
9696
97- result .QueryResultValue = queryResult
97+ result .QueryResultValue = fmt . Sprintf ( "%v" , queryResult )
9898
9999 // Handle schema checks specially
100100 if check .SchemaCheck != nil {
101101 // For schema checks, we expect the count to match the expected value
102102 if check .SchemaCheck .ExpectColumnsOrdered != nil {
103103 // For expect_columns_ordered, the count should match the number of expected columns
104104 expectedCount := len (check .SchemaCheck .ExpectColumnsOrdered .ColumnsOrder )
105- actualCount , err := strconv . Atoi (queryResult )
105+ actualCount , err := d . convertToInt (queryResult )
106106 if err != nil || actualCount != expectedCount {
107107 result .Pass = false
108- result .Error = fmt .Sprintf ("Check failed: %s == %d (got: %s )" , check .Expression , expectedCount , queryResult )
108+ result .Error = fmt .Sprintf ("Check failed: %s == %d (got: %v )" , check .Expression , expectedCount , queryResult )
109109 } else {
110110 result .Pass = true
111111 }
112112 } else if check .SchemaCheck .ExpectColumns != nil {
113113 // For expect_columns, the count should match the number of expected columns
114114 expectedCount := len (check .SchemaCheck .ExpectColumns .Columns )
115- actualCount , err := strconv . Atoi (queryResult )
115+ actualCount , err := d . convertToInt (queryResult )
116116 if err != nil || actualCount != expectedCount {
117117 result .Pass = false
118- result .Error = fmt .Sprintf ("Check failed: %s == %d (got: %s )" , check .Expression , expectedCount , queryResult )
118+ result .Error = fmt .Sprintf ("Check failed: %s == %d (got: %v )" , check .Expression , expectedCount , queryResult )
119119 } else {
120120 result .Pass = true
121121 }
122122 } else if check .SchemaCheck .ColumnsNotPresent != nil {
123123 // For columns_not_present, the count should be 0 (no unwanted columns should exist)
124- actualCount , err := strconv . Atoi (queryResult )
124+ actualCount , err := d . convertToInt (queryResult )
125125 if err != nil {
126126 result .Pass = false
127- result .Error = fmt .Sprintf ("Check failed: %s invalid result: %s " , check .Expression , queryResult )
127+ result .Error = fmt .Sprintf ("Check failed: %s invalid result: %v " , check .Expression , queryResult )
128128 } else if actualCount > 0 {
129129 result .Pass = false
130130 result .Error = fmt .Sprintf ("Check failed: %s found %d unwanted columns" , check .Expression , actualCount )
@@ -144,19 +144,19 @@ func (d DbqDataValidatorImpl) RunCheck(ctx context.Context, adapter DbqDataSourc
144144}
145145
146146// validateResult checks if the query result meets the check criteria
147- func (d DbqDataValidatorImpl ) validateResult (queryResult string , parsedCheck * CheckExpression ) bool {
147+ func (d DbqDataValidatorImpl ) validateResult (queryResult interface {} , parsedCheck * CheckExpression ) bool {
148148 if parsedCheck == nil {
149149 // If there's no parsed check, consider it a pass (raw queries without validation)
150150 return true
151151 }
152152
153153 // If there's no operator, just check if we got a result (for functions like raw_query)
154154 if parsedCheck .Operator == "" {
155- return queryResult != ""
155+ return queryResult != nil && fmt . Sprintf ( "%v" , queryResult ) != ""
156156 }
157157
158158 // Convert query result to float64 for numeric comparisons
159- actualValue , err := strconv . ParseFloat (queryResult , 64 )
159+ actualValue , err := d . convertToFloat64 (queryResult )
160160 if err != nil {
161161 d .logger .Warn ("Failed to parse query result as number, treating as string comparison" ,
162162 "result" , queryResult ,
@@ -187,22 +187,23 @@ func (d DbqDataValidatorImpl) validateResult(queryResult string, parsedCheck *Ch
187187}
188188
189189// validateStringResult handles string-based comparisons when numeric parsing fails
190- func (d DbqDataValidatorImpl ) validateStringResult (queryResult string , parsedCheck * CheckExpression ) bool {
190+ func (d DbqDataValidatorImpl ) validateStringResult (queryResult interface {}, parsedCheck * CheckExpression ) bool {
191+ queryResultStr := fmt .Sprintf ("%v" , queryResult )
191192 switch parsedCheck .Operator {
192193 case "==" , "=" :
193194 if thresholdStr , ok := parsedCheck .ThresholdValue .(string ); ok {
194- return queryResult == thresholdStr
195+ return queryResultStr == thresholdStr
195196 }
196- return queryResult == fmt .Sprintf ("%v" , parsedCheck .ThresholdValue )
197+ return queryResultStr == fmt .Sprintf ("%v" , parsedCheck .ThresholdValue )
197198 case "!=" , "<>" :
198199 if thresholdStr , ok := parsedCheck .ThresholdValue .(string ); ok {
199- return queryResult != thresholdStr
200+ return queryResultStr != thresholdStr
200201 }
201- return queryResult != fmt .Sprintf ("%v" , parsedCheck .ThresholdValue )
202+ return queryResultStr != fmt .Sprintf ("%v" , parsedCheck .ThresholdValue )
202203 default :
203204 d .logger .Warn ("String comparison not supported for operator, defaulting to false" ,
204205 "operator" , parsedCheck .Operator ,
205- "result" , queryResult )
206+ "result" , queryResultStr )
206207 return false
207208 }
208209}
@@ -340,3 +341,37 @@ func (d DbqDataValidatorImpl) convertToFloat64(value interface{}) (float64, erro
340341 return 0 , fmt .Errorf ("unsupported type: %T" , value )
341342 }
342343}
344+
345+ // convertToInt converts various types to int
346+ func (d DbqDataValidatorImpl ) convertToInt (value interface {}) (int , error ) {
347+ switch v := value .(type ) {
348+ case int :
349+ return v , nil
350+ case int8 :
351+ return int (v ), nil
352+ case int16 :
353+ return int (v ), nil
354+ case int32 :
355+ return int (v ), nil
356+ case int64 :
357+ return int (v ), nil
358+ case uint :
359+ return int (v ), nil
360+ case uint8 :
361+ return int (v ), nil
362+ case uint16 :
363+ return int (v ), nil
364+ case uint32 :
365+ return int (v ), nil
366+ case uint64 :
367+ return int (v ), nil
368+ case float32 :
369+ return int (v ), nil
370+ case float64 :
371+ return int (v ), nil
372+ case string :
373+ return strconv .Atoi (v )
374+ default :
375+ return 0 , fmt .Errorf ("unsupported type: %T" , value )
376+ }
377+ }
0 commit comments