import numpy as np
import pandas as pd
import json
import os
from typing import Dict, List, Optional
import logging
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def validate_parameters(params: Dict) -> Dict:
"""
Validate weight calculation parameters to prevent dangerous combinations.
Includes validation for focal loss parameters.
"""
# Check for dangerous weight scaling
if params['boost_factor'] * params['max_weight'] > 30:
raise ValueError(f"Dangerous weight scaling detected: boost_factor * max_weight = {params['boost_factor'] * params['max_weight']}")
# Validate focal loss parameters
if not 0 < params['gamma'] <= 5.0:
raise ValueError(f"Invalid gamma value: {params['gamma']}. Must be in (0, 5.0]")
if not 0 < params['alpha'] < 1:
raise ValueError(f"Invalid alpha value: {params['alpha']}. Must be in (0, 1)")
# Check for potentially unstable combinations
if params['gamma'] > 3.0 and params['boost_factor'] > 1.5:
logging.warning(f"Potentially unstable combination: high gamma ({params['gamma']}) with high boost factor ({params['boost_factor']})")
if params['alpha'] > 0.4 and params['boost_factor'] > 1.5:
logging.warning(f"Potentially unstable combination: high alpha ({params['alpha']}) with high boost factor ({params['boost_factor']})")
return params
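
# Illustrative check (not part of the pipeline): the default parameter set
# below passes validation, while boost_factor=3.0 with max_weight=15.0 would
# raise, since 3.0 * 15.0 = 45 > 30:
#
#   validate_parameters({"max_weight": 15.0, "min_weight": 0.5,
#                        "boost_factor": 1.0, "gamma": 2.0, "alpha": 0.25})
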
def calculate_safe_weights(
support_0: int,
support_1: int,
max_weight: float = 15.0,
min_weight: float = 0.5,
gamma: float = 2.0,
alpha: float = 0.25,
boost_factor: float = 1.0,
num_classes: int = 6,
    lang: Optional[str] = None,
    toxicity_type: Optional[str] = None
) -> Dict[str, float]:
"""
Calculate class weights with focal loss and adaptive scaling.
Uses focal loss components for better handling of imbalanced classes
while preserving language-specific adjustments.
Args:
support_0: Number of negative samples
support_1: Number of positive samples
max_weight: Maximum allowed weight
min_weight: Minimum allowed weight
gamma: Focal loss gamma parameter for down-weighting easy examples
alpha: Focal loss alpha parameter for balancing positive/negative classes
boost_factor: Optional boost for specific classes
num_classes: Number of toxicity classes (default=6)
lang: Language code for language-specific constraints
toxicity_type: Type of toxicity for class-specific constraints
"""
# Input validation with detailed error messages
if support_0 < 0 or support_1 < 0:
raise ValueError(f"Negative sample counts: support_0={support_0}, support_1={support_1}")
eps = 1e-7 # Small epsilon for numerical stability
total = support_0 + support_1 + eps
# Handle empty dataset case
if total <= eps:
logging.warning(f"Empty dataset for {toxicity_type} in {lang}")
return {
"0": 1.0,
"1": 1.0,
"support_0": support_0,
"support_1": support_1,
"raw_weight_1": 1.0,
"calculation_metadata": {
"formula": "default_weights_empty_dataset",
"constraints_applied": ["empty_dataset_fallback"]
}
}
# Handle zero support cases safely
if support_1 == 0:
logging.warning(f"No positive samples for {toxicity_type} in {lang}")
return {
"0": 1.0,
"1": max_weight,
"support_0": support_0,
"support_1": support_1,
"raw_weight_1": max_weight,
"calculation_metadata": {
"formula": "max_weight_no_positives",
"constraints_applied": ["no_positives_fallback"]
}
}
# Determine effective maximum weight based on class and language
if lang == 'en' and toxicity_type == 'threat':
effective_max = min(max_weight, 15.0) # Absolute cap for EN threat
elif toxicity_type == 'identity_hate':
effective_max = min(max_weight, 10.0) # Cap for identity hate
else:
effective_max = max_weight
try:
# Calculate class frequencies
freq_1 = support_1 / total
freq_0 = support_0 / total
# Focal loss components
pt = freq_1 + eps # Probability of target class
modulating_factor = (1 - pt) ** gamma
balanced_alpha = alpha / (alpha + (1 - alpha) * (1 - pt))
# Base weight calculation with focal loss
raw_weight_1 = balanced_alpha * modulating_factor / (pt + eps)
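        # Worked example (illustrative): support_1=100, support_0=9900 with
        # the default gamma=2.0, alpha=0.25 gives pt ≈ 0.01,
        # modulating_factor ≈ 0.980, balanced_alpha ≈ 0.252, so
        # raw_weight_1 ≈ 24.7 before the adjustments and clipping below.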
# Apply adaptive scaling for severe classes
if toxicity_type in ['threat', 'identity_hate']:
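            # severity_factor grows with positive-class rarity, e.g.
            # total=10000, support_1=100 gives
            # (1 + log1p(10000) / log1p(100)) / 2 ≈ 1.50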
severity_factor = (1 + np.log1p(total) / np.log1p(support_1)) / 2
raw_weight_1 *= severity_factor
# Apply boost factor
raw_weight_1 *= boost_factor
# Detect potential numerical instability
if not np.isfinite(raw_weight_1):
logging.error(f"Numerical instability detected for {toxicity_type} in {lang}")
raw_weight_1 = effective_max
    except Exception as e:
        logging.error(f"Weight calculation error: {str(e)}")
        # Fall back to the effective maximum; also define the focal terms so
        # the metadata block below cannot raise a NameError.
        pt, modulating_factor, balanced_alpha = 1.0, 0.0, alpha
        raw_weight_1 = effective_max
# Apply safety limits with effective maximum
weight_1 = min(effective_max, max(min_weight, raw_weight_1))
weight_0 = 1.0 # Reference weight for majority class
# Round weights for consistency and to prevent floating point issues
weight_1 = round(float(weight_1), 3)
weight_0 = round(float(weight_0), 3)
return {
"0": weight_0,
"1": weight_1,
"support_0": support_0,
"support_1": support_1,
"raw_weight_1": round(float(raw_weight_1), 3),
"calculation_metadata": {
"formula": "focal_loss_with_adaptive_scaling",
"gamma": round(float(gamma), 3),
"alpha": round(float(alpha), 3),
"final_pt": round(float(pt), 4),
"effective_max": round(float(effective_max), 3),
"modulating_factor": round(float(modulating_factor), 4),
"balanced_alpha": round(float(balanced_alpha), 4),
"severity_adjusted": toxicity_type in ['threat', 'identity_hate'],
"boost_factor": round(float(boost_factor), 3),
"constraints_applied": [
f"max_weight={effective_max}",
f"boost={boost_factor}",
f"numerical_stability=enforced",
f"adaptive_scaling={'enabled' if toxicity_type in ['threat', 'identity_hate'] else 'disabled'}"
]
}
}
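
# Minimal usage sketch (hypothetical counts) for a rare EN 'threat' class:
#
#   w = calculate_safe_weights(support_0=9900, support_1=100,
#                              lang='en', toxicity_type='threat')
#
# The raw focal weight (≈24.7, times severity ≈1.50) is clipped to the
# EN-threat cap, so w["1"] == 15.0 while w["0"] stays at 1.0.
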
def get_language_specific_params(lang: str, toxicity_type: str) -> Dict:
"""
Get language and class specific parameters for weight calculation.
Includes focal loss parameters and their adjustments per language/class.
"""
# Default parameters
default_params = {
"max_weight": 15.0,
"min_weight": 0.5,
"boost_factor": 1.0,
"gamma": 2.0, # Default focal loss gamma
"alpha": 0.25 # Default focal loss alpha
}
# Updated language-specific adjustments based on analysis
lang_adjustments = {
"en": {
"toxic": {
"boost_factor": 1.67, # To achieve ~3.5x weight
"gamma": 2.5 # More focus on hard examples for main class
},
"threat": {
"max_weight": 15.0, # Absolute maximum cap
"gamma": 3.0, # Higher gamma for severe class
"alpha": 0.3 # Slightly higher alpha for better recall
},
"identity_hate": {
"max_weight": 5.0, # Reduced from 8.4
"gamma": 3.0, # Higher gamma for severe class
"alpha": 0.3 # Slightly higher alpha for better recall
},
"severe_toxic": {
"max_weight": 3.9, # Corrected weight
"gamma": 2.5 # Moderate gamma for balance
}
},
"tr": {
"threat": {
"max_weight": 12.8, # Aligned with cross-lingual ratio
"gamma": 2.8 # Slightly lower than EN for stability
},
"identity_hate": {
"max_weight": 6.2, # Adjusted for balance
"gamma": 2.8 # Slightly lower than EN for stability
}
},
"ru": {
"threat": {
"max_weight": 12.8, # Aligned with cross-lingual ratio
"gamma": 2.8 # Slightly lower than EN for stability
},
"identity_hate": {
"max_weight": 7.0, # Adjusted for balance
"gamma": 2.8 # Slightly lower than EN for stability
}
},
"fr": {
"toxic": {
"boost_factor": 1.2, # To achieve ~2.2x weight
"gamma": 2.2 # Lower gamma for better stability
}
}
}
# Get language-specific params and validate
lang_params = lang_adjustments.get(lang, {})
class_params = lang_params.get(toxicity_type, {})
merged_params = {**default_params, **class_params}
return validate_parameters(merged_params)
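
# Example: get_language_specific_params('en', 'threat') merges the defaults
# with the EN 'threat' overrides above, yielding max_weight=15.0,
# min_weight=0.5, boost_factor=1.0, gamma=3.0, alpha=0.3 (then validated).
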
def check_cross_language_consistency(lang_weights: Dict) -> List[str]:
"""
Check for consistency of weights across languages.
Returns a list of warnings for significant disparities.
"""
    warnings = []
    # English serves as the cross-lingual baseline; skip the check if absent
    if 'en' not in lang_weights:
        logging.warning("No 'en' weights available; skipping cross-language consistency check")
        return warnings
    baseline = lang_weights['en']
for lang in lang_weights:
if lang == 'en':
continue
for cls in ['threat', 'identity_hate']:
if cls in lang_weights[lang] and cls in baseline:
ratio = lang_weights[lang][cls]['1'] / baseline[cls]['1']
if ratio > 1.5 or ratio < 0.67:
warning = f"Large {cls} weight disparity: {lang} vs en ({ratio:.2f}x)"
warnings.append(warning)
logging.warning(warning)
return warnings
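
# Example: a 'ru' threat weight of 12.0 against an 'en' baseline of 6.0 is a
# 2.00x ratio, outside the accepted [0.67x, 1.5x] band, so a disparity
# warning is emitted.
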
def validate_dataset_balance(df: pd.DataFrame) -> bool:
"""
Validate dataset balance across languages.
Returns False if imbalance exceeds threshold.
"""
    sample_counts = df.groupby('lang').size()
    # Coefficient of variation; std() is NaN for a single language (ddof=1),
    # so treat that case as balanced
    cv = sample_counts.std() / sample_counts.mean() if len(sample_counts) > 1 else 0.0
    if cv > 0.15:  # 15% threshold for coefficient of variation
logging.error(f"Dataset language imbalance exceeds 15% (CV={cv:.2%})")
for lang, count in sample_counts.items():
logging.warning(f"{lang}: {count:,} samples ({count/len(df):.1%})")
return False
return True
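
# Example: language counts of (50000, 50000, 40000) give CV ≈ 12.4% and pass,
# while (60000, 50000, 30000) give CV ≈ 32.7% and are flagged (figures assume
# pandas' default ddof=1 std).
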
def validate_weights(lang_weights: Dict) -> List[str]:
"""
Ensure weights meet multilingual safety criteria.
Validates weight ratios and focal loss parameters across languages.
Args:
lang_weights: Dictionary of weights per language and class
Returns:
List of validation warnings
Raises:
ValueError: If weights violate safety constraints
"""
warnings = []
for lang in lang_weights:
for cls in lang_weights[lang]:
w1 = lang_weights[lang][cls]['1']
w0 = lang_weights[lang][cls]['0']
# Check weight ratio sanity
ratio = w1 / w0
if ratio > 30:
raise ValueError(
f"Dangerous weight ratio {ratio:.1f}x for {lang} {cls}. "
f"Weight_1={w1:.3f}, Weight_0={w0:.3f}"
)
elif ratio > 20:
warnings.append(
f"High weight ratio {ratio:.1f}x for {lang} {cls}"
)
# Check focal parameter boundaries
metadata = lang_weights[lang][cls]['calculation_metadata']
gamma = metadata.get('gamma', 0.0)
alpha = metadata.get('alpha', 0.0)
if gamma > 5.0:
raise ValueError(
f"Unsafe gamma={gamma:.1f} for {lang} {cls}. "
f"Must be <= 5.0"
)
elif gamma > 4.0:
warnings.append(
f"High gamma={gamma:.1f} for {lang} {cls}"
)
if alpha > 0.9:
raise ValueError(
f"Unsafe alpha={alpha:.2f} for {lang} {cls}. "
f"Must be < 0.9"
)
elif alpha > 0.7:
warnings.append(
f"High alpha={alpha:.2f} for {lang} {cls}"
)
# Check for combined risk factors
if gamma > 3.0 and ratio > 15:
warnings.append(
f"Risky combination for {lang} {cls}: "
f"gamma={gamma:.1f}, ratio={ratio:.1f}x"
)
return warnings
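
# Example: weight_1=31.0 against weight_0=1.0 raises (ratio > 30), while
# weight_1=22.0 only appends a "High weight ratio 22.0x ..." warning.
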
def compute_language_weights(df: pd.DataFrame) -> Dict:
"""
Compute weights with inter-language normalization to ensure consistent
weighting across languages while preserving relative class relationships.
"""
# Validate dataset balance first
if not validate_dataset_balance(df):
logging.warning("Proceeding with imbalanced dataset - weights may need manual adjustment")
lang_weights = {}
toxicity_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
# First pass: calculate raw weights for each language and class
logging.info("\nFirst pass: Calculating raw weights")
for lang in df['lang'].unique():
logging.info(f"\nProcessing language: {lang}")
lang_df = df[df['lang'] == lang]
lang_weights[lang] = {}
for col in toxicity_columns:
y = lang_df[col].values.astype(np.int32)
support_0 = int((y == 0).sum())
support_1 = int((y == 1).sum())
params = get_language_specific_params(lang, col)
weights = calculate_safe_weights(
support_0=support_0,
support_1=support_1,
max_weight=params['max_weight'],
min_weight=params['min_weight'],
gamma=params['gamma'],
alpha=params['alpha'],
boost_factor=params['boost_factor'],
lang=lang,
toxicity_type=col
)
lang_weights[lang][col] = weights
# Log initial weights
logging.info(f" {col} - Initial weights:")
logging.info(f" Class 0: {weights['0']:.3f}, samples: {support_0:,}")
logging.info(f" Class 1: {weights['1']:.3f}, samples: {support_1:,}")
# Second pass: normalize weights across languages
logging.info("\nSecond pass: Normalizing weights across languages")
for col in toxicity_columns:
# Find maximum weight for this toxicity type across all languages
max_weight = max(
lang_weights[lang][col]['1']
for lang in lang_weights
)
if max_weight > 0: # Prevent division by zero
logging.info(f"\nNormalizing {col}:")
logging.info(f" Maximum weight across languages: {max_weight:.3f}")
# Normalize weights for each language
for lang in lang_weights:
original_weight = lang_weights[lang][col]['1']
                # Rescale so the largest weight across languages maps to the
                # global target maximum of 15.0
                normalized_weight = (original_weight / max_weight) * 15.0
# Update weight while preserving metadata
lang_weights[lang][col]['raw_weight_1'] = original_weight
lang_weights[lang][col]['1'] = round(normalized_weight, 3)
# Add normalization info to metadata
lang_weights[lang][col]['calculation_metadata'].update({
'normalization': {
'original_weight': round(float(original_weight), 3),
'max_weight_across_langs': round(float(max_weight), 3),
'normalization_factor': round(float(15.0 / max_weight), 3)
}
})
# Log normalization results
logging.info(f" {lang}: {original_weight:.3f}{normalized_weight:.3f}")
# Validate final weights
logging.info("\nValidating final weights:")
for col in toxicity_columns:
weights_range = [
lang_weights[lang][col]['1']
for lang in lang_weights
]
logging.info(f" {col}: range [{min(weights_range):.3f}, {max(weights_range):.3f}]")
# Validate weights meet safety criteria
validation_warnings = validate_weights(lang_weights)
if validation_warnings:
logging.warning("\nWeight validation warnings:")
for warning in validation_warnings:
logging.warning(f" {warning}")
# Check cross-language consistency
consistency_warnings = check_cross_language_consistency(lang_weights)
if consistency_warnings:
logging.warning("\nCross-language consistency warnings:")
for warning in consistency_warnings:
logging.warning(f" {warning}")
return lang_weights
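
# Normalization sketch: if 'threat' weights come out as {en: 10.0, ru: 12.8},
# the cross-language max is 12.8 and each weight is rescaled by
# 15.0 / 12.8 ≈ 1.172, giving {en: ≈11.72, ru: 15.0}.
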
def main():
# Load dataset
input_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_AUGMENTED.csv'
logging.info(f"Loading dataset from {input_file}")
df = pd.read_csv(input_file)
# Compute weights
lang_weights = compute_language_weights(df)
# Add metadata
weights_data = {
"metadata": {
"total_samples": len(df),
"language_distribution": df['lang'].value_counts().to_dict(),
"weight_calculation": {
"method": "focal_loss_with_adaptive_scaling",
"parameters": {
"default_max_weight": 15.0,
"default_min_weight": 0.5,
"language_specific_adjustments": True
}
}
},
"weights": lang_weights
}
    # Save weights (create the output directory if it does not exist)
    output_file = 'weights/language_class_weights.json'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
logging.info(f"\nSaving weights to {output_file}")
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(weights_data, f, indent=2, ensure_ascii=False)
logging.info("\nWeight calculation complete!")
# Print summary statistics
logging.info("\nSummary of adjustments made:")
for lang in lang_weights:
for col in ['threat', 'identity_hate']:
if col in lang_weights[lang]:
weight = lang_weights[lang][col]['1']
raw = lang_weights[lang][col]['raw_weight_1']
if raw != weight:
logging.info(f"{lang} {col}: Adjusted from {raw:.2f}× to {weight:.2f}×")
if __name__ == "__main__":
main()