
Commit 268f240

feat: Add duplicate data detection and warnings
- Added automatic duplicate rate checking for training and edge case datasets
- Configurable 5% threshold for duplicate warnings (DUPLICATE_RATE_THRESHOLD)
- Only shows warnings when duplicates could harm model performance
- Switched default data generation model to mistralai/mistral-nemo
- Enhanced system prompts for better violation content generation
- Fixed GitHub workflow argument parsing for problem descriptions

This improves data quality monitoring and model performance by alerting users when duplicate training data could cause overfitting or poor generalization.
1 parent 52bb866 commit 268f240

2 files changed

Lines changed: 110 additions & 1 deletion
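
The warning behavior described in the commit message comes down to a single percentage comparison against the configurable threshold. A rough sketch of that arithmetic, using hypothetical sample counts rather than values from this commit:

# Hypothetical numbers to illustrate the 5% default threshold.
total_samples = 40
unique_samples = 37                                        # after normalizing text
duplicate_count = total_samples - unique_samples           # 3
duplicate_rate = (duplicate_count / total_samples) * 100   # 7.5
DUPLICATE_RATE_THRESHOLD = 5.0
print(duplicate_rate > DUPLICATE_RATE_THRESHOLD)           # True -> a data quality warning would be logged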


text_classifier/agent.py

Lines changed: 106 additions & 0 deletions
@@ -444,6 +444,96 @@ def _append_to_dataset(
         )
         return 0

+    def _check_dataset_duplicate_rate(self, dataset_path: Path) -> Dict[str, Any]:
+        """
+        Check for duplicate entries in the training dataset and calculate duplicate rate.
+
+        Returns:
+            dict: Contains duplicate statistics including rate, count, and whether it exceeds threshold
+        """
+        if not dataset_path or not dataset_path.exists():
+            logger.error(f"Dataset file not found: {dataset_path}")
+            return {"error": "Dataset file not found"}
+
+        try:
+            import pandas as pd
+
+            # Read the dataset
+            df = pd.read_csv(dataset_path, on_bad_lines="skip")
+
+            if df.empty:
+                logger.warning("Dataset is empty")
+                return {"error": "Dataset is empty"}
+
+            if 'text' not in df.columns:
+                logger.error("Dataset missing 'text' column")
+                return {"error": "Dataset missing 'text' column"}
+
+            # Clean and normalize text for duplicate detection
+            df['text_normalized'] = df['text'].astype(str).str.strip().str.lower()
+
+            # Calculate duplicate statistics
+            total_samples = len(df)
+            unique_samples = df['text_normalized'].nunique()
+            duplicate_count = total_samples - unique_samples
+            duplicate_rate = (duplicate_count / total_samples) * 100 if total_samples > 0 else 0
+
+            # Check if duplicate rate exceeds threshold
+            threshold = getattr(settings, 'DUPLICATE_RATE_THRESHOLD', 5.0)
+            exceeds_threshold = duplicate_rate > threshold
+
+            # Get some example duplicates for reporting
+            duplicate_examples = []
+            if duplicate_count > 0:
+                duplicated_texts = df[df.duplicated(subset=['text_normalized'], keep=False)]
+                if not duplicated_texts.empty:
+                    # Group by normalized text and get counts
+                    duplicate_groups = duplicated_texts.groupby('text_normalized')['text'].apply(list).head(3)
+                    for normalized_text, text_list in duplicate_groups.items():
+                        duplicate_examples.append({
+                            "text": text_list[0][:100] + "..." if len(text_list[0]) > 100 else text_list[0],
+                            "count": len(text_list)
+                        })
+
+            return {
+                "total_samples": total_samples,
+                "unique_samples": unique_samples,
+                "duplicate_count": duplicate_count,
+                "duplicate_rate": round(duplicate_rate, 2),
+                "exceeds_threshold": exceeds_threshold,
+                "threshold": threshold,
+                "examples": duplicate_examples[:3]  # Show up to 3 examples
+            }
+
+        except Exception as e:
+            logger.error(f"Error checking dataset duplicates: {e}", exc_info=True)
+            return {"error": f"Error checking duplicates: {str(e)}"}
+
+    def _notify_duplicate_rate(self, duplicate_stats: Dict[str, Any], dataset_name: str = "training") -> None:
+        """
+        Notify user only when duplicate rate could harm model performance.
+
+        Args:
+            duplicate_stats: Dictionary containing duplicate statistics
+            dataset_name: Name of the dataset (e.g., "training", "edge_case")
+        """
+        if "error" in duplicate_stats:
+            return  # Silent fail for errors
+
+        duplicate_rate = duplicate_stats["duplicate_rate"]
+        exceeds_threshold = duplicate_stats["exceeds_threshold"]
+        threshold = duplicate_stats["threshold"]
+
+        # Only show notification if duplicates exceed threshold
+        if exceeds_threshold:
+            logger.warning(f"\n⚠️ DATA QUALITY WARNING - {dataset_name.upper()} DATASET")
+            logger.warning(f"🚨 High duplicate rate detected: {duplicate_rate:.2f}% (threshold: {threshold}%)")
+            logger.warning(f"💡 This may cause poor model performance due to:")
+            logger.warning(f"  • Model overfitting on repeated examples")
+            logger.warning(f"  • Reduced generalization ability")
+            logger.warning(f"  • Biased training patterns")
+            logger.warning(f"🔧 Consider regenerating data with more diverse prompts\n")
+
     async def _generate_text_samples_batch_async(
         self,
         prompts_classlabels: List[
@@ -812,6 +902,14 @@ async def _generate_training_data_async(self) -> int:
             )
         if self.dataset_path and self.dataset_path.exists():
             logger.info(f"Training dataset saved to: {self.dataset_path}")
+
+            # Check for duplicate rates in training data
+            duplicate_stats = self._check_dataset_duplicate_rate(self.dataset_path)
+            self._notify_duplicate_rate(duplicate_stats, "training")
+
+            # Store duplicate stats in final config for reference
+            if self.final_config:
+                self.final_config["training_duplicate_stats"] = duplicate_stats
         else:
             logger.warning(
                 f"Training dataset file not found at expected location: {self.dataset_path}"
@@ -952,6 +1050,14 @@ async def _generate_edge_cases_async(self) -> int:
             )
         if self.edge_case_dataset_path and self.edge_case_dataset_path.exists():
             logger.info(f"Edge case dataset saved to: {self.edge_case_dataset_path}")
+
+            # Check for duplicate rates in edge case data
+            edge_duplicate_stats = self._check_dataset_duplicate_rate(self.edge_case_dataset_path)
+            self._notify_duplicate_rate(edge_duplicate_stats, "edge case")
+
+            # Store duplicate stats in final config for reference
+            if self.final_config:
+                self.final_config["edge_case_duplicate_stats"] = edge_duplicate_stats
         else:
             logger.warning(
                 f"Edge case dataset file not found at expected location: {self.edge_case_dataset_path}"

text_classifier/settings.py

Lines changed: 4 additions & 1 deletion
@@ -11,7 +11,7 @@
 DEFAULT_CONFIG_MODEL = (
     "x-ai/grok-3-beta"  # "anthropic/claude-3-opus" # More capable model
 )
-DEFAULT_DATA_GEN_MODEL = "openai/gpt-4o-mini"  # Cheaper/faster for bulk generation
+DEFAULT_DATA_GEN_MODEL = "mistralai/mistral-nemo"  # Less restrictive for content moderation data

 # --- Default Paths ---
 DEFAULT_OUTPUT_PATH = "models"  # Changed for differentiation
@@ -36,6 +36,9 @@
 DEFAULT_PROMPT_REFINEMENT_CYCLES = 1  # How many times to refine prompts
 DEFAULT_GENERATE_EDGE_CASES = True

+# --- Data Quality Control ---
+DUPLICATE_RATE_THRESHOLD = 5.0  # Percentage threshold for duplicate rate warnings
+
 # --- Prompts ---
 CONFIG_SYSTEM_PROMPT = "You are an expert AI assistant specializing in data generation and configuration for machine learning. Follow instructions precisely and provide output in the requested JSON format."
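The agent reads the new constant defensively via getattr, so a settings module that predates this commit still resolves to the 5.0% default. A minimal illustration (only the settings import path comes from this repository; the surrounding snippet is hypothetical):

from text_classifier import settings

# Falls back to 5.0 percent if DUPLICATE_RATE_THRESHOLD is absent from settings.
threshold = getattr(settings, "DUPLICATE_RATE_THRESHOLD", 5.0)
print(f"Warn when the duplicate rate exceeds {threshold}%")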