import os

# Configure CPU and thread settings FIRST, before any other imports:
# the OMP/MKL thread counts are read when torch initialises its thread pools,
# so the environment variables must be set before torch is imported.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
os.environ['TF_CPU_ENABLE_AVX2'] = '1'
os.environ['TF_CPU_ENABLE_AVX512F'] = '1'
os.environ['TF_CPU_ENABLE_AVX512_VNNI'] = '1'
os.environ['TF_CPU_ENABLE_FMA'] = '1'
os.environ['MKL_NUM_THREADS'] = '80'
os.environ['OMP_NUM_THREADS'] = '80'

import torch

# Set PyTorch thread configurations once
torch.set_num_threads(80)
torch.set_num_interop_threads(10)
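
# NOTE: the 80/10 thread counts assume a high-core-count host; tune them (or
# drop the explicit calls and rely on PyTorch defaults) on smaller machines.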
# Now import everything else
import pandas as pd
import numpy as np
from pathlib import Path
import logging
from datetime import datetime
import sys
from toxic_augment import ToxicAugmenter
import json
# Configure logging
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
log_file = log_dir / f"balance_english_{timestamp}.log"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s | %(message)s',
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler(log_file)
]
)
logger = logging.getLogger(__name__)
def analyze_label_distribution(df, lang='en'):
"""Analyze label distribution for a specific language"""
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
lang_df = df[df['lang'] == lang]
total = len(lang_df)
if total == 0:
logger.warning(f"No samples found for language {lang.upper()}.")
return {}
logger.info(f"\nLabel Distribution for {lang.upper()}:")
logger.info("-" * 50)
dist = {}
for label in labels:
count = lang_df[label].sum()
percentage = (count / total) * 100
dist[label] = {'count': int(count), 'percentage': percentage}
logger.info(f"{label}: {count:,} ({percentage:.2f}%)")
return dist
def analyze_language_distribution(df):
"""Analyze current language distribution"""
lang_dist = df['lang'].value_counts()
logger.info("\nCurrent Language Distribution:")
logger.info("-" * 50)
for lang, count in lang_dist.items():
logger.info(f"{lang}: {count:,} comments ({count/len(df)*100:.2f}%)")
return lang_dist
def calculate_required_samples(df):
"""Calculate how many English samples we need to generate"""
lang_counts = df['lang'].value_counts()
target_count = lang_counts.max() # Use the largest language count as target
en_count = lang_counts.get('en', 0)
required_samples = target_count - en_count
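    # Example (hypothetical numbers): if the largest language has 60,000 rows
    # and English has 45,000, we would generate 15,000 extra English samples.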
logger.info(f"\nTarget count per language: {target_count:,}")
logger.info(f"Current English count: {en_count:,}")
logger.info(f"Required additional English samples: {required_samples:,}")
return required_samples
def generate_balanced_samples(df, required_samples):
"""Generate samples maintaining original class distribution ratios"""
logger.info("\nGenerating balanced samples...")
# Get English samples
en_df = df[df['lang'] == 'en']
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
# Calculate target counts for each label
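    # The ratios are per-label prevalences over all English rows; since the data
    # is multi-label, these targets overlap and need not sum to required_samples.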
target_counts = {}
for label in labels:
count = en_df[label].sum()
ratio = count / len(en_df)
target_count = int(ratio * required_samples)
target_counts[label] = target_count
logger.info(f"Target count for {label}: {target_count:,}")
augmented_samples = []
augmenter = ToxicAugmenter()
total_generated = 0
# Generate samples for each label
for label, target_count in target_counts.items():
if target_count == 0:
continue
logger.info(f"\nGenerating {target_count:,} samples for {label}")
# Get seed texts with this label
seed_texts = en_df[en_df[label] == 1]['comment_text'].tolist()
if not seed_texts:
logger.warning(f"No seed texts found for {label}, skipping...")
continue
# Generate samples with 5-minute timeout
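        # augment_dataset (from toxic_augment.py) is expected to return a DataFrame
        # with 'comment_text' plus the label columns; the exact schema is defined
        # by ToxicAugmenter.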
new_samples = augmenter.augment_dataset(
target_samples=target_count,
label=label, # Using single label instead of label_combo
seed_texts=seed_texts,
timeout_minutes=5
)
if new_samples is not None and not new_samples.empty:
augmented_samples.append(new_samples)
total_generated += len(new_samples)
# Log progress
logger.info(f"✓ Generated {len(new_samples):,} samples")
logger.info(f"Progress: {total_generated:,}/{required_samples:,}")
        # Stop once the overall required sample count has been reached
if total_generated >= required_samples:
logger.info("Reached required sample count, stopping generation")
break
# Combine all generated samples
if augmented_samples:
augmented_df = pd.concat(augmented_samples, ignore_index=True)
augmented_df['lang'] = 'en'
# Ensure we don't exceed the required sample count
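        # Note: head() keeps the earliest rows, so samples generated for the
        # later labels are the ones dropped when trimming.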
if len(augmented_df) > required_samples:
logger.info(f"Trimming excess samples from {len(augmented_df):,} to {required_samples:,}")
augmented_df = augmented_df.head(required_samples)
# Log final class distribution
logger.info("\nFinal class distribution in generated samples:")
for label in labels:
count = augmented_df[label].sum()
percentage = (count / len(augmented_df)) * 100
logger.info(f"{label}: {count:,} ({percentage:.2f}%)")
# Also log clean samples
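        # Rows where every label column is 0; these can appear if the augmenter
        # returns generations that end up with no positive labels.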
clean_count = len(augmented_df[augmented_df[labels].sum(axis=1) == 0])
clean_percentage = (clean_count / len(augmented_df)) * 100
logger.info(f"Clean samples: {clean_count:,} ({clean_percentage:.2f}%)")
return augmented_df
else:
        raise RuntimeError("Failed to generate any valid samples")
def balance_english_data():
"""Main function to balance English data with other languages"""
try:
# Load dataset
input_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv'
logger.info(f"Loading dataset from {input_file}")
df = pd.read_csv(input_file)
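        # The CSV is expected to provide 'comment_text', 'lang', and the six
        # toxicity label columns used throughout this script.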
# Analyze current distribution
logger.info("\nAnalyzing current distribution...")
initial_dist = analyze_language_distribution(df)
initial_label_dist = analyze_label_distribution(df, 'en')
# Calculate required samples
required_samples = calculate_required_samples(df)
if required_samples <= 0:
logger.info("English data is already balanced. No augmentation needed.")
return
# Generate balanced samples
augmented_df = generate_balanced_samples(df, required_samples)
# Merge with original dataset
logger.info("\nMerging datasets...")
output_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_BALANCED.csv'
# Combine datasets
combined_df = pd.concat([df, augmented_df], ignore_index=True)
# Save balanced dataset
combined_df.to_csv(output_file, index=False)
logger.info(f"\nSaved balanced dataset to {output_file}")
# Final distribution check
logger.info("\nFinal distribution after balancing:")
final_dist = analyze_language_distribution(combined_df)
final_label_dist = analyze_label_distribution(combined_df, 'en')
# Save distribution statistics
        stats = {
            'timestamp': timestamp,
            'initial_distribution': {
                # value_counts() yields numpy integers; cast to plain ints so json.dump succeeds
                'languages': {lang: int(count) for lang, count in initial_dist.items()},
                'english_labels': initial_label_dist
            },
            'final_distribution': {
                'languages': {lang: int(count) for lang, count in final_dist.items()},
                'english_labels': final_label_dist
            },
            'samples_generated': len(augmented_df),
            'total_samples': len(combined_df)
        }
stats_file = f'logs/balance_stats_{timestamp}.json'
with open(stats_file, 'w') as f:
json.dump(stats, f, indent=2)
logger.info(f"\nSaved balancing statistics to {stats_file}")
except Exception as e:
        logger.exception(f"Error during balancing: {e}")
raise
def main():
balance_english_data()
if __name__ == "__main__":
logger.info("Starting English data balancing process...")
main()