import numpy as np
import pandas as pd
import json
import os
from typing import Dict, List, Optional
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

def validate_parameters(params: Dict) -> Dict:
    """
    Validate weight calculation parameters to prevent dangerous combinations.
    Includes validation for focal loss parameters.
    """
    # Check for dangerous weight scaling
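    # Illustrative check using values from this file (the 2.5 case is hypothetical):
    # boost_factor=1.67 with max_weight=15.0 gives 1.67 * 15.0 = 25.05, which passes;
    # boost_factor=2.5 with the same cap gives 37.5 and would be rejected here.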
    if params['boost_factor'] * params['max_weight'] > 30:
        raise ValueError(f"Dangerous weight scaling detected: boost_factor * max_weight = {params['boost_factor'] * params['max_weight']}")
    
    # Validate focal loss parameters
    if not 0 < params['gamma'] <= 5.0:
        raise ValueError(f"Invalid gamma value: {params['gamma']}. Must be in (0, 5.0]")
    
    if not 0 < params['alpha'] < 1:
        raise ValueError(f"Invalid alpha value: {params['alpha']}. Must be in (0, 1)")
    
    # Check for potentially unstable combinations
    if params['gamma'] > 3.0 and params['boost_factor'] > 1.5:
        logging.warning(f"Potentially unstable combination: high gamma ({params['gamma']}) with high boost factor ({params['boost_factor']})")
    
    if params['alpha'] > 0.4 and params['boost_factor'] > 1.5:
        logging.warning(f"Potentially unstable combination: high alpha ({params['alpha']}) with high boost factor ({params['boost_factor']})")
    
    return params

def calculate_safe_weights(
    support_0: int,
    support_1: int,
    max_weight: float = 15.0,
    min_weight: float = 0.5,
    gamma: float = 2.0,
    alpha: float = 0.25,
    boost_factor: float = 1.0,
    num_classes: int = 6,
    lang: Optional[str] = None,
    toxicity_type: Optional[str] = None
) -> Dict[str, float]:
    """
    Calculate class weights with focal loss and adaptive scaling.
    Uses focal loss components for better handling of imbalanced classes
    while preserving language-specific adjustments.
    
    Args:
        support_0: Number of negative samples
        support_1: Number of positive samples
        max_weight: Maximum allowed weight
        min_weight: Minimum allowed weight
        gamma: Focal loss gamma parameter for down-weighting easy examples
        alpha: Focal loss alpha parameter for balancing positive/negative classes
        boost_factor: Optional boost for specific classes
        num_classes: Number of toxicity classes (default=6; reserved, not used in the current calculation)
        lang: Language code for language-specific constraints
        toxicity_type: Type of toxicity for class-specific constraints
    """
    # Input validation with detailed error messages
    if support_0 < 0 or support_1 < 0:
        raise ValueError(f"Negative sample counts: support_0={support_0}, support_1={support_1}")
    
    eps = 1e-7  # Small epsilon for numerical stability
    total = support_0 + support_1 + eps
    
    # Handle empty dataset case
    if total <= eps:
        logging.warning(f"Empty dataset for {toxicity_type} in {lang}")
        return {
            "0": 1.0,
            "1": 1.0,
            "support_0": support_0,
            "support_1": support_1,
            "raw_weight_1": 1.0,
            "calculation_metadata": {
                "formula": "default_weights_empty_dataset",
                "constraints_applied": ["empty_dataset_fallback"]
            }
        }
    
    # Handle zero support cases safely
    if support_1 == 0:
        logging.warning(f"No positive samples for {toxicity_type} in {lang}")
        return {
            "0": 1.0,
            "1": max_weight,
            "support_0": support_0,
            "support_1": support_1,
            "raw_weight_1": max_weight,
            "calculation_metadata": {
                "formula": "max_weight_no_positives",
                "constraints_applied": ["no_positives_fallback"]
            }
        }
    
    # Determine effective maximum weight based on class and language
    if lang == 'en' and toxicity_type == 'threat':
        effective_max = min(max_weight, 15.0)  # Absolute cap for EN threat
    elif toxicity_type == 'identity_hate':
        effective_max = min(max_weight, 10.0)  # Cap for identity hate
    else:
        effective_max = max_weight
    
    # Fallback values so the metadata block below is always defined even if the
    # focal-loss calculation fails before these are assigned (NaN marks "not computed")
    pt = modulating_factor = balanced_alpha = float("nan")

    try:
        # Calculate class frequencies
        freq_1 = support_1 / total
        freq_0 = support_0 / total
        
        # Focal loss components
        pt = freq_1 + eps  # Probability of target class
        modulating_factor = (1 - pt) ** gamma
        balanced_alpha = alpha / (alpha + (1 - alpha) * (1 - pt))
        
        # Base weight calculation with focal loss
        raw_weight_1 = balanced_alpha * modulating_factor / (pt + eps)
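        # Worked example (illustrative numbers, not project data): with pt ~= 0.01,
        # gamma=2.0 and alpha=0.25, modulating_factor ~= 0.99**2 ~= 0.980 and
        # balanced_alpha ~= 0.25 / (0.25 + 0.75 * 0.99) ~= 0.252, so
        # raw_weight_1 ~= 0.252 * 0.980 / 0.01 ~= 24.7 before the clamping applied below.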
        
        # Apply adaptive scaling for severe classes
        if toxicity_type in ['threat', 'identity_hate']:
            severity_factor = (1 + np.log1p(total) / np.log1p(support_1)) / 2
            raw_weight_1 *= severity_factor
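            # Rough magnitude of this boost (illustrative): with total ~= 100,000 and
            # support_1 = 100, log1p(total) / log1p(support_1) ~= 11.51 / 4.62 ~= 2.49,
            # so severity_factor ~= (1 + 2.49) / 2 ~= 1.75x on top of the focal weight.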
        
        # Apply boost factor
        raw_weight_1 *= boost_factor
        
        # Detect potential numerical instability
        if not np.isfinite(raw_weight_1):
            logging.error(f"Numerical instability detected for {toxicity_type} in {lang}")
            raw_weight_1 = effective_max
            
    except Exception as e:
        logging.error(f"Weight calculation error: {str(e)}")
        raw_weight_1 = effective_max
    
    # Apply safety limits with effective maximum
    weight_1 = min(effective_max, max(min_weight, raw_weight_1))
    weight_0 = 1.0  # Reference weight for majority class
    
    # Round weights for consistency and to prevent floating point issues
    weight_1 = round(float(weight_1), 3)
    weight_0 = round(float(weight_0), 3)
    
    return {
        "0": weight_0,
        "1": weight_1,
        "support_0": support_0,
        "support_1": support_1,
        "raw_weight_1": round(float(raw_weight_1), 3),
        "calculation_metadata": {
            "formula": "focal_loss_with_adaptive_scaling",
            "gamma": round(float(gamma), 3),
            "alpha": round(float(alpha), 3),
            "final_pt": round(float(pt), 4),
            "effective_max": round(float(effective_max), 3),
            "modulating_factor": round(float(modulating_factor), 4),
            "balanced_alpha": round(float(balanced_alpha), 4),
            "severity_adjusted": toxicity_type in ['threat', 'identity_hate'],
            "boost_factor": round(float(boost_factor), 3),
            "constraints_applied": [
                f"max_weight={effective_max}",
                f"boost={boost_factor}",
                f"numerical_stability=enforced",
                f"adaptive_scaling={'enabled' if toxicity_type in ['threat', 'identity_hate'] else 'disabled'}"
            ]
        }
    }
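
# Example call (hypothetical counts, not taken from the project's dataset):
#   calculate_safe_weights(support_0=45_000, support_1=5_000)
# uses the defaults gamma=2.0, alpha=0.25, boost_factor=1.0 and yields
# raw_weight_1 ~= 2.189, which lies inside [min_weight, max_weight] and is
# therefore returned unclamped as weight "1" = 2.189.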

def get_language_specific_params(lang: str, toxicity_type: str) -> Dict:
    """
    Get language and class specific parameters for weight calculation.
    Includes focal loss parameters and their adjustments per language/class.
    """
    # Default parameters
    default_params = {
        "max_weight": 15.0,
        "min_weight": 0.5,
        "boost_factor": 1.0,
        "gamma": 2.0,  # Default focal loss gamma
        "alpha": 0.25  # Default focal loss alpha
    }
    
    # Updated language-specific adjustments based on analysis
    lang_adjustments = {
        "en": {
            "toxic": {
                "boost_factor": 1.67,  # To achieve ~3.5x weight
                "gamma": 2.5  # More focus on hard examples for main class
            },
            "threat": {
                "max_weight": 15.0,  # Absolute maximum cap
                "gamma": 3.0,  # Higher gamma for severe class
                "alpha": 0.3  # Slightly higher alpha for better recall
            },
            "identity_hate": {
                "max_weight": 5.0,  # Reduced from 8.4
                "gamma": 3.0,  # Higher gamma for severe class
                "alpha": 0.3  # Slightly higher alpha for better recall
            },
            "severe_toxic": {
                "max_weight": 3.9,  # Corrected weight
                "gamma": 2.5  # Moderate gamma for balance
            }
        },
        "tr": {
            "threat": {
                "max_weight": 12.8,  # Aligned with cross-lingual ratio
                "gamma": 2.8  # Slightly lower than EN for stability
            },
            "identity_hate": {
                "max_weight": 6.2,  # Adjusted for balance
                "gamma": 2.8  # Slightly lower than EN for stability
            }
        },
        "ru": {
            "threat": {
                "max_weight": 12.8,  # Aligned with cross-lingual ratio
                "gamma": 2.8  # Slightly lower than EN for stability
            },
            "identity_hate": {
                "max_weight": 7.0,  # Adjusted for balance
                "gamma": 2.8  # Slightly lower than EN for stability
            }
        },
        "fr": {
            "toxic": {
                "boost_factor": 1.2,  # To achieve ~2.2x weight
                "gamma": 2.2  # Lower gamma for better stability
            }
        }
    }
    
    # Get language-specific params and validate
    lang_params = lang_adjustments.get(lang, {})
    class_params = lang_params.get(toxicity_type, {})
    merged_params = {**default_params, **class_params}
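    # Example of the merge (derived from the tables above): for lang='en' and
    # toxicity_type='threat' the class entry overrides max_weight, gamma and alpha,
    # giving {'max_weight': 15.0, 'min_weight': 0.5, 'boost_factor': 1.0,
    #         'gamma': 3.0, 'alpha': 0.3} before validation.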
    
    return validate_parameters(merged_params)

def check_cross_language_consistency(lang_weights: Dict) -> List[str]:
    """
    Check for consistency of weights across languages.
    Returns a list of warnings for significant disparities.
    """
    warnings = []

    # Guard against datasets without an English split, which this check uses as its baseline
    if 'en' not in lang_weights:
        logging.warning("No 'en' weights found; skipping cross-language consistency check")
        return warnings

    baseline = lang_weights['en']
    
    for lang in lang_weights:
        if lang == 'en':
            continue
            
        for cls in ['threat', 'identity_hate']:
            if cls in lang_weights[lang] and cls in baseline:
                ratio = lang_weights[lang][cls]['1'] / baseline[cls]['1']
                if ratio > 1.5 or ratio < 0.67:
                    warning = f"Large {cls} weight disparity: {lang} vs en ({ratio:.2f}x)"
                    warnings.append(warning)
                    logging.warning(warning)
    
    return warnings

def validate_dataset_balance(df: pd.DataFrame) -> bool:
    """
    Validate dataset balance across languages.
    Returns False if imbalance exceeds threshold.
    """
    sample_counts = df.groupby('lang').size()
    cv = sample_counts.std() / sample_counts.mean()
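    # Example (hypothetical counts): languages with 50,000 / 48,000 / 52,000 samples
    # have mean 50,000 and sample std 2,000, so cv = 0.04 (4%) and the check passes;
    # counts of 60,000 / 40,000 / 50,000 give cv = 0.2 (20%) and would fail below.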
    
    if cv > 0.15:  # 15% threshold for coefficient of variation
        logging.error(f"Dataset language imbalance exceeds 15% (CV={cv:.2%})")
        for lang, count in sample_counts.items():
            logging.warning(f"{lang}: {count:,} samples ({count/len(df):.1%})")
        return False
    return True

def validate_weights(lang_weights: Dict) -> List[str]:
    """
    Ensure weights meet multilingual safety criteria.
    Validates weight ratios and focal loss parameters across languages.
    
    Args:
        lang_weights: Dictionary of weights per language and class
        
    Returns:
        List of validation warnings
        
    Raises:
        ValueError: If weights violate safety constraints
    """
    warnings = []
    
    for lang in lang_weights:
        for cls in lang_weights[lang]:
            w1 = lang_weights[lang][cls]['1']
            w0 = lang_weights[lang][cls]['0']
            
            # Check weight ratio sanity
            ratio = w1 / w0
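            # Note: calculate_safe_weights fixes weight_0 at 1.0, so in practice the
            # ratio equals weight_1 directly (e.g. weight_1=15.0 -> ratio 15.0,
            # below the 20x warning and 30x rejection thresholds checked next).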
            if ratio > 30:
                raise ValueError(
                    f"Dangerous weight ratio {ratio:.1f}x for {lang} {cls}. "
                    f"Weight_1={w1:.3f}, Weight_0={w0:.3f}"
                )
            elif ratio > 20:
                warnings.append(
                    f"High weight ratio {ratio:.1f}x for {lang} {cls}"
                )
            
            # Check focal parameter boundaries
            metadata = lang_weights[lang][cls]['calculation_metadata']
            gamma = metadata.get('gamma', 0.0)
            alpha = metadata.get('alpha', 0.0)
            
            if gamma > 5.0:
                raise ValueError(
                    f"Unsafe gamma={gamma:.1f} for {lang} {cls}. "
                    f"Must be <= 5.0"
                )
            elif gamma > 4.0:
                warnings.append(
                    f"High gamma={gamma:.1f} for {lang} {cls}"
                )
                
            if alpha > 0.9:
                raise ValueError(
                    f"Unsafe alpha={alpha:.2f} for {lang} {cls}. "
                    f"Must be < 0.9"
                )
            elif alpha > 0.7:
                warnings.append(
                    f"High alpha={alpha:.2f} for {lang} {cls}"
                )
            
            # Check for combined risk factors
            if gamma > 3.0 and ratio > 15:
                warnings.append(
                    f"Risky combination for {lang} {cls}: "
                    f"gamma={gamma:.1f}, ratio={ratio:.1f}x"
                )
    
    return warnings

def compute_language_weights(df: pd.DataFrame) -> Dict:
    """
    Compute weights with inter-language normalization to ensure consistent 
    weighting across languages while preserving relative class relationships.
    """
    # Validate dataset balance first
    if not validate_dataset_balance(df):
        logging.warning("Proceeding with imbalanced dataset - weights may need manual adjustment")
    
    lang_weights = {}
    toxicity_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    
    # First pass: calculate raw weights for each language and class
    logging.info("\nFirst pass: Calculating raw weights")
    for lang in df['lang'].unique():
        logging.info(f"\nProcessing language: {lang}")
        lang_df = df[df['lang'] == lang]
        lang_weights[lang] = {}
        
        for col in toxicity_columns:
            y = lang_df[col].values.astype(np.int32)
            support_0 = int((y == 0).sum())
            support_1 = int((y == 1).sum())
            
            params = get_language_specific_params(lang, col)
            weights = calculate_safe_weights(
                support_0=support_0,
                support_1=support_1,
                max_weight=params['max_weight'],
                min_weight=params['min_weight'],
                gamma=params['gamma'],
                alpha=params['alpha'],
                boost_factor=params['boost_factor'],
                lang=lang,
                toxicity_type=col
            )
            lang_weights[lang][col] = weights
            
            # Log initial weights
            logging.info(f"  {col} - Initial weights:")
            logging.info(f"    Class 0: {weights['0']:.3f}, samples: {support_0:,}")
            logging.info(f"    Class 1: {weights['1']:.3f}, samples: {support_1:,}")
    
    # Second pass: normalize weights across languages
    logging.info("\nSecond pass: Normalizing weights across languages")
    for col in toxicity_columns:
        # Find maximum weight for this toxicity type across all languages
        max_weight = max(
            lang_weights[lang][col]['1'] 
            for lang in lang_weights
        )
        
        if max_weight > 0:  # Prevent division by zero
            logging.info(f"\nNormalizing {col}:")
            logging.info(f"  Maximum weight across languages: {max_weight:.3f}")
            
            # Normalize weights for each language
            for lang in lang_weights:
                original_weight = lang_weights[lang][col]['1']
                
                # Normalize and rescale
                normalized_weight = (original_weight / max_weight) * 15.0
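                # Illustration (hypothetical weights): if 'threat' came out as 10.0 for
                # en and 6.0 for fr, max_weight is 10.0, so en is rescaled to 15.0 and
                # fr to (6.0 / 10.0) * 15.0 = 9.0, preserving their relative ordering.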
                
                # Update weight while preserving metadata
                lang_weights[lang][col]['raw_weight_1'] = original_weight
                lang_weights[lang][col]['1'] = round(normalized_weight, 3)
                
                # Add normalization info to metadata
                lang_weights[lang][col]['calculation_metadata'].update({
                    'normalization': {
                        'original_weight': round(float(original_weight), 3),
                        'max_weight_across_langs': round(float(max_weight), 3),
                        'normalization_factor': round(float(15.0 / max_weight), 3)
                    }
                })
                
                # Log normalization results
                logging.info(f"  {lang}: {original_weight:.3f}{normalized_weight:.3f}")
    
    # Validate final weights
    logging.info("\nValidating final weights:")
    for col in toxicity_columns:
        weights_range = [
            lang_weights[lang][col]['1']
            for lang in lang_weights
        ]
        logging.info(f"  {col}: range [{min(weights_range):.3f}, {max(weights_range):.3f}]")
    
    # Validate weights meet safety criteria
    validation_warnings = validate_weights(lang_weights)
    if validation_warnings:
        logging.warning("\nWeight validation warnings:")
        for warning in validation_warnings:
            logging.warning(f"  {warning}")
    
    # Check cross-language consistency
    consistency_warnings = check_cross_language_consistency(lang_weights)
    if consistency_warnings:
        logging.warning("\nCross-language consistency warnings:")
        for warning in consistency_warnings:
            logging.warning(f"  {warning}")
    
    return lang_weights

def main():
    # Load dataset
    input_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_AUGMENTED.csv'
    logging.info(f"Loading dataset from {input_file}")
    df = pd.read_csv(input_file)
    
    # Compute weights
    lang_weights = compute_language_weights(df)
    
    # Add metadata
    weights_data = {
        "metadata": {
            "total_samples": len(df),
            "language_distribution": df['lang'].value_counts().to_dict(),
            "weight_calculation": {
                "method": "focal_loss_with_adaptive_scaling",
                "parameters": {
                    "default_max_weight": 15.0,
                    "default_min_weight": 0.5,
                    "language_specific_adjustments": True
                }
            }
        },
        "weights": lang_weights
    }
    
    # Save weights (create the output directory if it does not exist yet)
    output_file = 'weights/language_class_weights.json'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    logging.info(f"\nSaving weights to {output_file}")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(weights_data, f, indent=2, ensure_ascii=False)
    
    logging.info("\nWeight calculation complete!")
    
    # Print summary statistics
    logging.info("\nSummary of adjustments made:")
    for lang in lang_weights:
        for col in ['threat', 'identity_hate']:
            if col in lang_weights[lang]:
                weight = lang_weights[lang][col]['1']
                raw = lang_weights[lang][col]['raw_weight_1']
                if raw != weight:
                    logging.info(f"{lang} {col}: Adjusted from {raw:.2f}× to {weight:.2f}×")

if __name__ == "__main__":
    main()