"""Augment a multilingual toxicity dataset with synthetic English threat samples.

Candidate threats are generated with a 4-bit quantized Mistral-7B-Instruct model
and screened with a fast TF-IDF + logistic-regression validator.
"""
import gc
import json
import logging
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import List

import joblib
import pandas as pd
import torch
from langdetect import detect, LangDetectException
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

# Per-run log file under logs/, named by the run's start timestamp.
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = log_dir / f"generation_{timestamp}.log"

# Mirror every log record to stdout and to the log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_file)
    ]
)

logger = logging.getLogger(__name__)
logger.info(f"Starting new run. Log file: {log_file}")


def log_separator(message: str = ""):
    """Print a separator line with optional message"""
    if message:
        logger.info("\n" + "=" * 40 + f" {message} " + "=" * 40)
    else:
        logger.info("\n" + "=" * 100)
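
# e.g. log_separator("INIT") logs " INIT " framed by 40 "=" on each side,
# while log_separator() logs a bare line of 100 "=".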


class FastThreatValidator:
    """Fast threat validation using logistic regression"""

    def __init__(self, model_path: str = "weights/threat_validator.joblib"):
        self.model_path = model_path
        # Reuse a trained validator from disk when available; otherwise train
        # one from the labeled split and save it.
        if Path(model_path).exists():
            logger.info("Loading fast threat validator...")
            model_data = joblib.load(model_path)
            self.vectorizer = model_data['vectorizer']
            self.model = model_data['model']
            logger.info("✓ Fast validator loaded")
        else:
            logger.info("Training fast threat validator...")
            self._train_validator()
            logger.info("✓ Fast validator trained and saved")

    def _train_validator(self):
        """Train a simple logistic regression model for threat detection"""
        train_df = pd.read_csv("dataset/split/train.csv")

        X = train_df['comment_text'].fillna('')
        y = train_df['threat']

        # Unigram + bigram TF-IDF features; terms seen in fewer than 2
        # documents are dropped.
        self.vectorizer = TfidfVectorizer(
            max_features=10000,
            ngram_range=(1, 2),
            strip_accents='unicode',
            min_df=2
        )
        X_vec = self.vectorizer.fit_transform(X)

        # 'balanced' reweights classes to offset how rare threat labels are.
        self.model = LogisticRegression(
            C=1.0,
            class_weight='balanced',
            max_iter=200,
            n_jobs=-1
        )
        self.model.fit(X_vec, y)

        # Make sure the weights directory exists before saving.
        Path(self.model_path).parent.mkdir(parents=True, exist_ok=True)
        joblib.dump({
            'vectorizer': self.vectorizer,
            'model': self.model
        }, self.model_path)

    def validate(self, texts: List[str], threshold: float = 0.6) -> List[bool]:
        """Validate texts using the fast model"""
        X = self.vectorizer.transform(texts)

        # Probability of the positive (threat) class for each text.
        probs = self.model.predict_proba(X)[:, 1]

        # Convert the numpy boolean mask to a plain list to match the annotation.
        return (probs >= threshold).tolist()
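

# Minimal usage sketch for the validator (illustrative; the strings are made up):
#
#   validator = FastThreatValidator()
#   keep = validator.validate(["some comment", "another comment"], threshold=0.6)
#   threats = [t for t, ok in zip(["some comment", "another comment"], keep) if ok]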


class ThreatAugmenter:
    def __init__(self, seed_samples_path: str = "dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv"):
        log_separator("INITIALIZATION")

        self.log_file = log_file

        # Generation records are buffered and flushed to disk in chunks.
        self.generation_buffer = []
        self.buffer_size = 100

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 0:
            # Allow TF32 matmul/cuDNN kernels for faster inference on Ampere+.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            logger.info(f"Found {self.num_gpus} GPUs:")
            for i in range(self.num_gpus):
                mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
                logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")
log_separator("LOADING MODELS") |
|
logger.info("Loading Mistral-7B...") |
|
|
|
|
|
quantization_config = BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_compute_dtype=torch.float16, |
|
bnb_4bit_quant_type="nf4", |
|
bnb_4bit_use_double_quant=True |
|
) |
|
|
|
self.llm = AutoModelForCausalLM.from_pretrained( |
|
"mistralai/Mistral-7B-Instruct-v0.3", |
|
device_map="balanced", |
|
torch_dtype=torch.float16, |
|
quantization_config=quantization_config, |
|
max_memory={0: "22GB", 1: "22GB"} |
|
) |
|
|
|
self.llm_tokenizer = AutoTokenizer.from_pretrained( |
|
"mistralai/Mistral-7B-Instruct-v0.3", |
|
padding_side="left", |
|
use_fast=True |
|
) |
|
self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token |
|
logger.info("✓ Mistral-7B loaded") |
|
|
|
|
|

        self.validator = FastThreatValidator()
        logger.info("✓ Fast validator initialized")

        log_separator("LOADING DATA")
        logger.info(f"Loading samples from {seed_samples_path}")
        self.seed_df = pd.read_csv(seed_samples_path)
        self.en_threat_samples = self.seed_df[
            (self.seed_df['lang'] == 'en') &
            (self.seed_df['threat'] == 1)
        ]
        logger.info(f"Found {len(self.en_threat_samples)} EN threat samples")

        # Batching knobs (upper bounds; augment_dataset takes its own batch_size).
        self.max_batch_size = 48
        self.prefetch_factor = 4
        self.num_workers = 8

        torch.cuda.empty_cache()
        gc.collect()
        log_separator()

    def generate_prompt(self, seed_text: str) -> str:
        """Generate a prompt for realistic threatening content"""
        return f"""<s>[INST] Generate a threatening social media comment that is realistic and intimidating.

Reference example (generate something different): {seed_text}

Requirements:
- Make it genuinely threatening
- Use authentic social media language
- Keep it under 50 words
- Must be different from example

Generate ONLY the comment: [/INST]"""
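
    # Note (assumed from the Mistral instruct format): decoded with
    # skip_special_tokens=False, each batch entry looks like
    # "<s>[INST] ...prompt... [/INST] generated reply</s>", possibly with
    # leading </s> pad tokens from left padding; generate_samples() splits on
    # "[/INST]" and "</s>" to recover just the reply.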

    def flush_buffer(self):
        """Flush the generation buffer to disk"""
        if self.generation_buffer:
            try:
                # Append one JSON object per line (JSONL) to the run log file.
                with open(self.log_file, 'a', encoding='utf-8') as f:
                    for entry in self.generation_buffer:
                        f.write(json.dumps(entry, ensure_ascii=False) + '\n')
                self.generation_buffer = []
            except Exception as e:
                logger.error(f"Failed to flush buffer: {str(e)}")

    def log_generation(self, seed_text: str, prompt: str, generated_text: str, is_valid: bool):
        """Buffer one generation record for later flushing"""
        log_entry = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "seed_text": seed_text,
            "prompt": prompt,
            "generated_text": generated_text,
            "is_valid": is_valid
        }

        self.generation_buffer.append(log_entry)

        # Flush once the buffer reaches its configured size.
        if len(self.generation_buffer) >= self.buffer_size:
            self.flush_buffer()
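
    # Each buffered entry becomes one JSON line in the run log; an illustrative
    # (made-up) record:
    #   {"timestamp": "2024-01-01 12:00:00", "seed_text": "...", "prompt": "...",
    #    "generated_text": "...", "is_valid": true}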

    def generate_samples(self, prompts: List[str], seed_texts: List[str]) -> List[str]:
        try:
            with torch.amp.autocast('cuda', dtype=torch.float16):
                inputs = self.llm_tokenizer(prompts, return_tensors="pt", padding=True,
                                            truncation=True, max_length=256).to(self.llm.device)

                outputs = self.llm.generate(
                    **inputs,
                    max_new_tokens=32,
                    temperature=0.95,
                    do_sample=True,
                    top_p=0.92,
                    top_k=50,
                    num_return_sequences=1,
                    repetition_penalty=1.15,
                    pad_token_id=self.llm_tokenizer.pad_token_id,
                    eos_token_id=self.llm_tokenizer.eos_token_id
                )

            texts = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=False)
            cleaned_texts = []
            valid_count = 0

            for idx, text in enumerate(texts):
                if "[/INST]" in text and "</s>" in text:
                    # Keep only the reply between the instruction tag and EOS.
                    response = text.split("[/INST]")[1].split("</s>")[0].strip()
                    response = response.strip('"').strip("'")

                    # Reject replies that are too short/long or echo the prompt.
                    word_count = len(response.split())
                    is_valid = (3 <= word_count <= 50 and
                                not any(x in response.lower() for x in [
                                    "generate", "requirements:", "reference",
                                    "[inst]", "example"
                                ]))
                    # Record every attempt (seed, prompt, output, verdict)
                    # through the buffered JSONL logger.
                    self.log_generation(seed_texts[idx], prompts[idx], response, is_valid)
                    if is_valid:
                        cleaned_texts.append(response)
                        valid_count += 1

            if valid_count > 0:
                logger.info(f"\nBatch Success: {valid_count}/{len(texts)} ({valid_count/len(texts)*100:.1f}%)")

            return cleaned_texts

        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            return []

    def validate_toxicity(self, texts: List[str]) -> torch.Tensor:
        """Validate texts using fast logistic regression"""
        if not texts:
            return torch.zeros(0, dtype=torch.bool)

        validation_mask = self.validator.validate(texts)

        return torch.tensor(validation_mask, dtype=torch.bool, device=self.llm.device)

    def validate_language(self, texts: List[str]) -> List[bool]:
        """Simple language validation; undetectable texts fail the check"""
        def is_english(text: str) -> bool:
            try:
                return detect(text) == 'en'
            except LangDetectException:  # raised on empty/featureless strings
                return False
        return [is_english(t) for t in texts]
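
    # Note: langdetect is nondeterministic across runs unless seeded. For
    # reproducible checks, one can set once at import time (standard
    # langdetect API):
    #
    #   from langdetect import DetectorFactory
    #   DetectorFactory.seed = 0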

    def augment_dataset(self, target_samples: int = 500, batch_size: int = 32):
        """Main augmentation loop with progress bar and CSV saving"""
        try:
            start_time = time.time()
            logger.info(f"Starting generation: target={target_samples}, batch_size={batch_size}")
            generated_samples = []
            stats = {
                "total_attempts": 0,
                "valid_samples": 0,
                "batch_times": []
            }

            output_dir = Path("dataset/augmented")
            output_dir.mkdir(parents=True, exist_ok=True)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = output_dir / f"threat_augmented_{timestamp}.csv"

            pbar = tqdm(total=target_samples,
                        desc="Generating samples",
                        unit="samples",
                        ncols=100,
                        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')

            while len(generated_samples) < target_samples:
                batch_start = time.time()

                seed_texts = self.en_threat_samples['comment_text'].sample(batch_size).tolist()
                prompts = [self.generate_prompt(text) for text in seed_texts]
                new_samples = self.generate_samples(prompts, seed_texts)

                # Every prompt counts as an attempt, even when a batch yields nothing.
                stats["total_attempts"] += batch_size

                if not new_samples:
                    continue

                # Screen candidates with the fast threat validator and an
                # English-language check before accepting them.
                threat_mask = self.validate_toxicity(new_samples).tolist()
                lang_mask = self.validate_language(new_samples)
                new_samples = [t for t, ok_threat, ok_lang
                               in zip(new_samples, threat_mask, lang_mask)
                               if ok_threat and ok_lang]

                batch_time = time.time() - batch_start
                stats["batch_times"].append(batch_time)
                prev_len = len(generated_samples)
                generated_samples.extend(new_samples)
                stats["valid_samples"] = len(generated_samples)

                pbar.update(len(generated_samples) - prev_len)

                if len(stats["batch_times"]) % 10 == 0:
                    success_rate = (stats["valid_samples"] / stats["total_attempts"]) * 100
                    avg_batch_time = sum(stats["batch_times"][-20:]) / min(len(stats["batch_times"]), 20)
                    pbar.set_postfix({
                        'Success Rate': f'{success_rate:.1f}%',
                        'Batch Time': f'{avg_batch_time:.2f}s'
                    })

                # Periodically release cached GPU memory (every 5 completed batches).
                if len(stats["batch_times"]) % 5 == 0:
                    torch.cuda.empty_cache()
                    gc.collect()

            pbar.close()
            # Write any generation records still sitting in the JSONL buffer.
            self.flush_buffer()

            df = pd.DataFrame({
                'text': generated_samples[:target_samples],
                'label': 1,
                'source': 'augmented',
                'timestamp': timestamp
            })

            df.to_csv(output_file, index=False)
            logger.info(f"\nSaved {len(df)} samples to {output_file}")

            total_time = str(timedelta(seconds=int(time.time() - start_time)))
            logger.info(f"Generation complete: {len(generated_samples)} samples generated in {total_time}")

            return df

        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            raise
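
# The saved CSV has columns text,label,source,timestamp, where every augmented
# row carries label=1 and source='augmented'. An illustrative (made-up) row:
#   "example generated comment",1,augmented,20240101_120000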


if __name__ == "__main__":
    # Free any leftover GPU memory before loading the 7B model.
    torch.cuda.empty_cache()
    gc.collect()

    augmenter = ThreatAugmenter()
    augmented_df = augmenter.augment_dataset(target_samples=500)