66
77from testgen .common import read_template_sql_file
88from testgen .common .clean_sql import concat_columns
9- from testgen .common .database .database_service import get_flavor_service , replace_params
9+ from testgen .common .database .database_service import get_flavor_service , get_tg_schema , replace_params
1010from testgen .common .models .connection import Connection
1111from testgen .common .models .table_group import TableGroup
1212from testgen .common .models .test_definition import TestRunType , TestScope
@@ -107,7 +107,7 @@ def _get_params(self, test_def: TestExecutionDef | None = None) -> dict:
107107 "TEST_SUITE_ID" : self .test_run .test_suite_id ,
108108 "TEST_RUN_ID" : self .test_run .id ,
109109 "RUN_DATE" : self .run_date ,
110- "SQL_FLAVOR" : self .flavor ,
110+ "SQL_FLAVOR" : self .flavor ,
111111 "VARCHAR_TYPE" : self .flavor_service .varchar_type ,
112112 "QUOTE" : quote ,
113113 }
@@ -116,7 +116,9 @@ def _get_params(self, test_def: TestExecutionDef | None = None) -> dict:
116116 params .update ({
117117 "TEST_TYPE" : test_def .test_type ,
118118 "TEST_DEFINITION_ID" : test_def .id ,
119+ "APP_SCHEMA_NAME" : get_tg_schema (),
119120 "SCHEMA_NAME" : test_def .schema_name ,
121+ "TABLE_GROUPS_ID" : self .table_group .id ,
120122 "TABLE_NAME" : test_def .table_name ,
121123 "COLUMN_NAME" : f"{ quote } { test_def .column_name or '' } { quote } " ,
122124 "COLUMN_NAME_NO_QUOTES" : test_def .column_name ,
@@ -146,7 +148,7 @@ def _get_params(self, test_def: TestExecutionDef | None = None) -> dict:
146148 "MATCH_HAVING_CONDITION" : f"HAVING { test_def .match_having_condition } " if test_def .match_having_condition else "" ,
147149 "CUSTOM_QUERY" : test_def .custom_query ,
148150 "COLUMN_TYPE" : test_def .column_type ,
149- "INPUT_PARAMETERS" : self ._get_input_parameters (test_def ),
151+ "INPUT_PARAMETERS" : self ._get_input_parameters (test_def ),
150152 })
151153 return params
152154
@@ -169,11 +171,11 @@ def _get_query(
169171 query = query .replace (":" , "\\ :" )
170172
171173 return query , None if no_bind else params
172-
174+
173175 def get_active_test_definitions (self ) -> tuple [dict ]:
174176 # Runs on App database
175177 return self ._get_query ("get_active_test_definitions.sql" )
176-
178+
177179 def get_target_identifiers (self , schemas : Iterable [str ]) -> tuple [str , dict ]:
178180 # Runs on Target database
179181 filename = "get_target_identifiers.sql"
@@ -185,7 +187,7 @@ def get_target_identifiers(self, schemas: Iterable[str]) -> tuple[str, dict]:
185187 return self ._get_query (filename , f"flavors/{ self .connection .sql_flavor } /validate_tests" , extra_params = params )
186188 except ModuleNotFoundError :
187189 return self ._get_query (filename , "flavors/generic/validate_tests" , extra_params = params )
188-
190+
189191 def get_test_errors (self , test_defs : list [TestExecutionDef ]) -> list [list [UUID | str | datetime ]]:
190192 return [
191193 [
@@ -205,15 +207,15 @@ def get_test_errors(self, test_defs: list[TestExecutionDef]) -> list[list[UUID |
205207 None , # No result_measure on errors
206208 ] for td in test_defs if td .errors
207209 ]
208-
210+
209211 def disable_invalid_test_definitions (self ) -> tuple [str , dict ]:
210212 # Runs on App database
211213 return self ._get_query ("disable_invalid_test_definitions.sql" )
212-
214+
213215 def update_historic_thresholds (self ) -> tuple [str , dict ]:
214216 # Runs on App database
215217 return self ._get_query ("update_historic_thresholds.sql" )
216-
218+
217219 def run_query_test (self , test_def : TestExecutionDef ) -> tuple [str , dict ]:
218220 # Runs on Target database
219221 folder = "generic" if test_def .template_name .endswith ("_generic.sql" ) else self .flavor
@@ -225,7 +227,7 @@ def run_query_test(self, test_def: TestExecutionDef) -> tuple[str, dict]:
225227 extra_params = {"DATA_SCHEMA" : test_def .schema_name },
226228 test_def = test_def ,
227229 )
228-
230+
229231 def aggregate_cat_tests (
230232 self ,
231233 test_defs : list [TestExecutionDef ],
@@ -265,7 +267,7 @@ def add_query(test_defs: list[TestExecutionDef]) -> str:
265267
266268 aggregate_queries .append ((query , None ))
267269 aggregate_test_defs .append (test_defs )
268-
270+
269271 if single :
270272 for td in test_defs :
271273 # Add separate query for each test
@@ -296,9 +298,9 @@ def add_query(test_defs: list[TestExecutionDef]) -> str:
296298 current_test_defs .append (td )
297299
298300 add_query (current_test_defs )
299-
301+
300302 return aggregate_queries , aggregate_test_defs
301-
303+
302304 def get_cat_test_results (
303305 self ,
304306 aggregate_results : list [AggregateResult ],
@@ -309,7 +311,7 @@ def get_cat_test_results(
309311 test_defs = aggregate_test_defs [result ["query_index" ]]
310312 result_measures = result ["result_measures" ].split ("|" )
311313 result_codes = result ["result_codes" ].split ("," )
312-
314+
313315 for index , td in enumerate (test_defs ):
314316 test_results .append ([
315317 self .test_run .id ,
@@ -329,7 +331,7 @@ def get_cat_test_results(
329331 ])
330332
331333 return test_results
332-
334+
333335 def update_test_results (self ) -> list [tuple [str , dict ]]:
334336 # Runs on App database
335337 return [
0 commit comments