import json
import logging
import os
from typing import Dict, List, Optional

import numpy as np
import pandas as pd

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)


def validate_parameters(params: Dict) -> Dict:
    """
    Validate weight calculation parameters to prevent dangerous combinations.
    Includes validation for focal loss parameters.
    """
    if params['boost_factor'] * params['max_weight'] > 30:
        raise ValueError(
            f"Dangerous weight scaling detected: boost_factor * max_weight = "
            f"{params['boost_factor'] * params['max_weight']}"
        )

    if not 0 < params['gamma'] <= 5.0:
        raise ValueError(f"Invalid gamma value: {params['gamma']}. Must be in (0, 5.0]")

    if not 0 < params['alpha'] < 1:
        raise ValueError(f"Invalid alpha value: {params['alpha']}. Must be in (0, 1)")

    if params['gamma'] > 3.0 and params['boost_factor'] > 1.5:
        logging.warning(
            f"Potentially unstable combination: high gamma ({params['gamma']}) "
            f"with high boost factor ({params['boost_factor']})"
        )

    if params['alpha'] > 0.4 and params['boost_factor'] > 1.5:
        logging.warning(
            f"Potentially unstable combination: high alpha ({params['alpha']}) "
            f"with high boost factor ({params['boost_factor']})"
        )

    return params
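
# Illustrative usage only (hypothetical parameter values, not from the shipped
# configs): validate_parameters passes safe parameter sets through unchanged
# and raises on dangerous combinations.
#
#   validate_parameters({"max_weight": 15.0, "min_weight": 0.5,
#                        "boost_factor": 1.0, "gamma": 2.0, "alpha": 0.25})
#   # -> returns the dict unchanged (15.0 * 1.0 <= 30)
#
#   validate_parameters({"max_weight": 25.0, "min_weight": 0.5,
#                        "boost_factor": 2.0, "gamma": 2.0, "alpha": 0.25})
#   # -> raises ValueError (25.0 * 2.0 = 50 > 30)

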
def calculate_safe_weights(
    support_0: int,
    support_1: int,
    max_weight: float = 15.0,
    min_weight: float = 0.5,
    gamma: float = 2.0,
    alpha: float = 0.25,
    boost_factor: float = 1.0,
    num_classes: int = 6,
    lang: Optional[str] = None,
    toxicity_type: Optional[str] = None
) -> Dict:
""" |
|
Calculate class weights with focal loss and adaptive scaling. |
|
Uses focal loss components for better handling of imbalanced classes |
|
while preserving language-specific adjustments. |
|
|
|
Args: |
|
support_0: Number of negative samples |
|
support_1: Number of positive samples |
|
max_weight: Maximum allowed weight |
|
min_weight: Minimum allowed weight |
|
gamma: Focal loss gamma parameter for down-weighting easy examples |
|
alpha: Focal loss alpha parameter for balancing positive/negative classes |
|
boost_factor: Optional boost for specific classes |
|
num_classes: Number of toxicity classes (default=6) |
|
lang: Language code for language-specific constraints |
|
toxicity_type: Type of toxicity for class-specific constraints |
|
""" |
|
|
|
if support_0 < 0 or support_1 < 0: |
|
raise ValueError(f"Negative sample counts: support_0={support_0}, support_1={support_1}") |
|
|
|
eps = 1e-7 |
|
total = support_0 + support_1 + eps |
|
|
|
|
|
if total <= eps: |
|
logging.warning(f"Empty dataset for {toxicity_type} in {lang}") |
|
return { |
|
"0": 1.0, |
|
"1": 1.0, |
|
"support_0": support_0, |
|
"support_1": support_1, |
|
"raw_weight_1": 1.0, |
|
"calculation_metadata": { |
|
"formula": "default_weights_empty_dataset", |
|
"constraints_applied": ["empty_dataset_fallback"] |
|
} |
|
} |
|
|
|
|
|
    # No positive samples: fall back to the maximum allowed weight.
    if support_1 == 0:
        logging.warning(f"No positive samples for {toxicity_type} in {lang}")
        return {
            "0": 1.0,
            "1": max_weight,
            "support_0": support_0,
            "support_1": support_1,
            "raw_weight_1": max_weight,
            "calculation_metadata": {
                "formula": "max_weight_no_positives",
                "constraints_applied": ["no_positives_fallback"]
            }
        }

    # Language/class-specific caps on the maximum weight.
    if lang == 'en' and toxicity_type == 'threat':
        effective_max = min(max_weight, 15.0)
    elif toxicity_type == 'identity_hate':
        effective_max = min(max_weight, 10.0)
    else:
        effective_max = max_weight

    try:
        # Class frequencies.
        freq_1 = support_1 / total
        freq_0 = support_0 / total

        # Focal-loss components: pt is the positive-class frequency, the
        # modulating factor (1 - pt) ** gamma approaches 1 as positives get
        # rarer, and balanced_alpha rebalances the positive class via alpha.
        pt = freq_1 + eps
        modulating_factor = (1 - pt) ** gamma
        balanced_alpha = alpha / (alpha + (1 - alpha) * (1 - pt))

        # Base weight for the positive class.
        raw_weight_1 = balanced_alpha * modulating_factor / (pt + eps)

        # Adaptive scaling for high-severity classes.
        if toxicity_type in ['threat', 'identity_hate']:
            severity_factor = (1 + np.log1p(total) / np.log1p(support_1)) / 2
            raw_weight_1 *= severity_factor

        # Optional class-specific boost.
        raw_weight_1 *= boost_factor

        if not np.isfinite(raw_weight_1):
            logging.error(f"Numerical instability detected for {toxicity_type} in {lang}")
            raw_weight_1 = effective_max

    except Exception as e:
        logging.error(f"Weight calculation error: {str(e)}")
        raw_weight_1 = effective_max
        # Fall back to neutral focal components so the metadata in the return
        # value below is still defined when the calculation fails.
        pt = 1.0
        modulating_factor = 0.0
        balanced_alpha = alpha

    # Clamp to the allowed range and round for readability.
    weight_1 = min(effective_max, max(min_weight, raw_weight_1))
    weight_0 = 1.0

    weight_1 = round(float(weight_1), 3)
    weight_0 = round(float(weight_0), 3)

    return {
        "0": weight_0,
        "1": weight_1,
        "support_0": support_0,
        "support_1": support_1,
        "raw_weight_1": round(float(raw_weight_1), 3),
        "calculation_metadata": {
            "formula": "focal_loss_with_adaptive_scaling",
            "gamma": round(float(gamma), 3),
            "alpha": round(float(alpha), 3),
            "final_pt": round(float(pt), 4),
            "effective_max": round(float(effective_max), 3),
            "modulating_factor": round(float(modulating_factor), 4),
            "balanced_alpha": round(float(balanced_alpha), 4),
            "severity_adjusted": toxicity_type in ['threat', 'identity_hate'],
            "boost_factor": round(float(boost_factor), 3),
            "constraints_applied": [
                f"max_weight={effective_max}",
                f"boost={boost_factor}",
                "numerical_stability=enforced",
                f"adaptive_scaling={'enabled' if toxicity_type in ['threat', 'identity_hate'] else 'disabled'}"
            ]
        }
    }
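
# Worked example (illustrative counts only, not from a real dataset): with
# support_0=9_900, support_1=100 and the defaults (gamma=2.0, alpha=0.25,
# boost_factor=1.0, max_weight=15.0, no severity scaling):
#
#   pt                ≈ 100 / 10_000                 = 0.01
#   modulating_factor ≈ (1 - 0.01) ** 2.0            ≈ 0.980
#   balanced_alpha    ≈ 0.25 / (0.25 + 0.75 * 0.99)  ≈ 0.252
#   raw_weight_1      ≈ 0.252 * 0.980 / 0.01         ≈ 24.7
#
# The raw weight is then clamped to effective_max, so the returned class-1
# weight would be 15.0 while raw_weight_1 records the unclamped ≈24.7.

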
def get_language_specific_params(lang: str, toxicity_type: str) -> Dict:
    """
    Get language and class specific parameters for weight calculation.
    Includes focal loss parameters and their adjustments per language/class.
    """
    default_params = {
        "max_weight": 15.0,
        "min_weight": 0.5,
        "boost_factor": 1.0,
        "gamma": 2.0,
        "alpha": 0.25
    }

    lang_adjustments = {
        "en": {
            "toxic": {
                "boost_factor": 1.67,
                "gamma": 2.5
            },
            "threat": {
                "max_weight": 15.0,
                "gamma": 3.0,
                "alpha": 0.3
            },
            "identity_hate": {
                "max_weight": 5.0,
                "gamma": 3.0,
                "alpha": 0.3
            },
            "severe_toxic": {
                "max_weight": 3.9,
                "gamma": 2.5
            }
        },
        "tr": {
            "threat": {
                "max_weight": 12.8,
                "gamma": 2.8
            },
            "identity_hate": {
                "max_weight": 6.2,
                "gamma": 2.8
            }
        },
        "ru": {
            "threat": {
                "max_weight": 12.8,
                "gamma": 2.8
            },
            "identity_hate": {
                "max_weight": 7.0,
                "gamma": 2.8
            }
        },
        "fr": {
            "toxic": {
                "boost_factor": 1.2,
                "gamma": 2.2
            }
        }
    }

    lang_params = lang_adjustments.get(lang, {})
    class_params = lang_params.get(toxicity_type, {})
    merged_params = {**default_params, **class_params}

    return validate_parameters(merged_params)
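
# Example (values follow directly from the tables above): for English threats
# the class-specific overrides are merged over the defaults, so
#
#   get_language_specific_params('en', 'threat')
#   # -> {'max_weight': 15.0, 'min_weight': 0.5, 'boost_factor': 1.0,
#   #     'gamma': 3.0, 'alpha': 0.3}
#
# Unknown languages or classes simply fall back to default_params; every
# merged result is passed through validate_parameters before being returned.

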
def check_cross_language_consistency(lang_weights: Dict) -> List[str]:
    """
    Check for consistency of weights across languages.
    Returns a list of warnings for significant disparities.
    """
    warnings = []

    # English serves as the baseline; without it there is nothing to compare against.
    if 'en' not in lang_weights:
        logging.warning("No English weights available for cross-language consistency check")
        return warnings

    baseline = lang_weights['en']

    for lang in lang_weights:
        if lang == 'en':
            continue

        for cls in ['threat', 'identity_hate']:
            if cls in lang_weights[lang] and cls in baseline:
                ratio = lang_weights[lang][cls]['1'] / baseline[cls]['1']
                if ratio > 1.5 or ratio < 0.67:
                    warning = f"Large {cls} weight disparity: {lang} vs en ({ratio:.2f}x)"
                    warnings.append(warning)
                    logging.warning(warning)

    return warnings
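
# Example (hypothetical weights): if the final class-1 weight for 'threat' is
# 10.0 for English and 16.0 for Russian, the ratio is 1.60 > 1.5, so the check
# appends and logs:
#
#   "Large threat weight disparity: ru vs en (1.60x)"

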
def validate_dataset_balance(df: pd.DataFrame) -> bool:
    """
    Validate dataset balance across languages.
    Returns False if imbalance exceeds threshold.
    """
    sample_counts = df.groupby('lang').size()
    cv = sample_counts.std() / sample_counts.mean()

    if cv > 0.15:
        logging.error(f"Dataset language imbalance exceeds 15% (CV={cv:.2%})")
        for lang, count in sample_counts.items():
            logging.warning(f"{lang}: {count:,} samples ({count/len(df):.1%})")
        return False
    return True
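
# Example (hypothetical counts): per-language sample counts of 100_000, 98_000
# and 60_000 give a mean of 86_000 and a sample standard deviation of ≈22_539,
# so the coefficient of variation is ≈0.26 > 0.15 and the function logs the
# per-language shares and returns False.

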
def validate_weights(lang_weights: Dict) -> List[str]:
    """
    Ensure weights meet multilingual safety criteria.
    Validates weight ratios and focal loss parameters across languages.

    Args:
        lang_weights: Dictionary of weights per language and class

    Returns:
        List of validation warnings

    Raises:
        ValueError: If weights violate safety constraints
    """
    warnings = []

    for lang in lang_weights:
        for cls in lang_weights[lang]:
            w1 = lang_weights[lang][cls]['1']
            w0 = lang_weights[lang][cls]['0']

            ratio = w1 / w0
            if ratio > 30:
                raise ValueError(
                    f"Dangerous weight ratio {ratio:.1f}x for {lang} {cls}. "
                    f"Weight_1={w1:.3f}, Weight_0={w0:.3f}"
                )
            elif ratio > 20:
                warnings.append(
                    f"High weight ratio {ratio:.1f}x for {lang} {cls}"
                )

            metadata = lang_weights[lang][cls]['calculation_metadata']
            gamma = metadata.get('gamma', 0.0)
            alpha = metadata.get('alpha', 0.0)

            if gamma > 5.0:
                raise ValueError(
                    f"Unsafe gamma={gamma:.1f} for {lang} {cls}. "
                    f"Must be <= 5.0"
                )
            elif gamma > 4.0:
                warnings.append(
                    f"High gamma={gamma:.1f} for {lang} {cls}"
                )

            if alpha > 0.9:
                raise ValueError(
                    f"Unsafe alpha={alpha:.2f} for {lang} {cls}. "
                    f"Must be < 0.9"
                )
            elif alpha > 0.7:
                warnings.append(
                    f"High alpha={alpha:.2f} for {lang} {cls}"
                )

            if gamma > 3.0 and ratio > 15:
                warnings.append(
                    f"Risky combination for {lang} {cls}: "
                    f"gamma={gamma:.1f}, ratio={ratio:.1f}x"
                )

    return warnings
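
# Example (hypothetical weights): weight_1=22.0 with weight_0=1.0 gives a
# 22.0x ratio, which is above the 20x warning threshold but below the 30x hard
# limit, so a "High weight ratio 22.0x ..." warning is collected instead of a
# ValueError being raised.

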
def compute_language_weights(df: pd.DataFrame) -> Dict:
    """
    Compute weights with inter-language normalization to ensure consistent
    weighting across languages while preserving relative class relationships.
    """
    if not validate_dataset_balance(df):
        logging.warning("Proceeding with imbalanced dataset - weights may need manual adjustment")

    lang_weights = {}
    toxicity_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    logging.info("\nFirst pass: Calculating raw weights")
    for lang in df['lang'].unique():
        logging.info(f"\nProcessing language: {lang}")
        lang_df = df[df['lang'] == lang]
        lang_weights[lang] = {}

        for col in toxicity_columns:
            y = lang_df[col].values.astype(np.int32)
            support_0 = int((y == 0).sum())
            support_1 = int((y == 1).sum())

            params = get_language_specific_params(lang, col)
            weights = calculate_safe_weights(
                support_0=support_0,
                support_1=support_1,
                max_weight=params['max_weight'],
                min_weight=params['min_weight'],
                gamma=params['gamma'],
                alpha=params['alpha'],
                boost_factor=params['boost_factor'],
                lang=lang,
                toxicity_type=col
            )
            lang_weights[lang][col] = weights

            logging.info(f"  {col} - Initial weights:")
            logging.info(f"    Class 0: {weights['0']:.3f}, samples: {support_0:,}")
            logging.info(f"    Class 1: {weights['1']:.3f}, samples: {support_1:,}")

logging.info("\nSecond pass: Normalizing weights across languages") |
|
for col in toxicity_columns: |
|
|
|
max_weight = max( |
|
lang_weights[lang][col]['1'] |
|
for lang in lang_weights |
|
) |
|
|
|
if max_weight > 0: |
|
logging.info(f"\nNormalizing {col}:") |
|
logging.info(f" Maximum weight across languages: {max_weight:.3f}") |
|
|
|
|
|
for lang in lang_weights: |
|
original_weight = lang_weights[lang][col]['1'] |
|
|
|
|
|
normalized_weight = (original_weight / max_weight) * 15.0 |
|
|
|
|
|
lang_weights[lang][col]['raw_weight_1'] = original_weight |
|
lang_weights[lang][col]['1'] = round(normalized_weight, 3) |
|
|
|
|
|
lang_weights[lang][col]['calculation_metadata'].update({ |
|
'normalization': { |
|
'original_weight': round(float(original_weight), 3), |
|
'max_weight_across_langs': round(float(max_weight), 3), |
|
'normalization_factor': round(float(15.0 / max_weight), 3) |
|
} |
|
}) |
|
|
|
|
|
logging.info(f" {lang}: {original_weight:.3f} → {normalized_weight:.3f}") |
|
|
|
|
|
logging.info("\nValidating final weights:") |
|
for col in toxicity_columns: |
|
weights_range = [ |
|
lang_weights[lang][col]['1'] |
|
for lang in lang_weights |
|
] |
|
logging.info(f" {col}: range [{min(weights_range):.3f}, {max(weights_range):.3f}]") |
|
|
|
|
|
validation_warnings = validate_weights(lang_weights) |
|
if validation_warnings: |
|
logging.warning("\nWeight validation warnings:") |
|
for warning in validation_warnings: |
|
logging.warning(f" {warning}") |
|
|
|
|
|
consistency_warnings = check_cross_language_consistency(lang_weights) |
|
if consistency_warnings: |
|
logging.warning("\nCross-language consistency warnings:") |
|
for warning in consistency_warnings: |
|
logging.warning(f" {warning}") |
|
|
|
return lang_weights |
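
# Normalization example (hypothetical weights): if the largest class-1 weight
# for 'threat' across languages is 24.0 and one language has 12.0, the second
# pass rescales it to (12.0 / 24.0) * 15.0 = 7.5, while the language holding
# the maximum ends up at exactly 15.0. The pre-normalization value is kept in
# 'raw_weight_1' and in the 'normalization' metadata.

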
def main():
    """Load the dataset, compute normalized per-language class weights, and save them."""
    input_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_AUGMENTED.csv'
    logging.info(f"Loading dataset from {input_file}")
    df = pd.read_csv(input_file)

    lang_weights = compute_language_weights(df)

    weights_data = {
        "metadata": {
            "total_samples": len(df),
            "language_distribution": df['lang'].value_counts().to_dict(),
            "weight_calculation": {
                "method": "focal_loss_with_adaptive_scaling",
                "parameters": {
                    "default_max_weight": 15.0,
                    "default_min_weight": 0.5,
                    "language_specific_adjustments": True
                }
            }
        },
        "weights": lang_weights
    }

    output_file = 'weights/language_class_weights.json'
    logging.info(f"\nSaving weights to {output_file}")
    # Make sure the output directory exists before writing.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(weights_data, f, indent=2, ensure_ascii=False)

    logging.info("\nWeight calculation complete!")

    logging.info("\nSummary of adjustments made:")
    for lang in lang_weights:
        for col in ['threat', 'identity_hate']:
            if col in lang_weights[lang]:
                weight = lang_weights[lang][col]['1']
                raw = lang_weights[lang][col]['raw_weight_1']
                if raw != weight:
                    logging.info(f"{lang} {col}: Adjusted from {raw:.2f}× to {weight:.2f}×")


if __name__ == "__main__":
    main()