|
--- |
|
license: mit |
|
datasets: |
|
- OxAISH-AL-LLM/wiki_toxic |
|
- textdetox/multilingual_toxic_spans |
|
language: |
|
- en |
|
base_model: |
|
- openai-community/gpt2 |
|
|
|
--- |
|
|
|
# Model Card for Toxic Text GEN |
|
|
|
This model is a Decision Transformer for text generation with controlled toxicity (on a 0-1 scale).
|
|
|
## Model Details |
|
|
|
### Model Description |
|
|
|
Built as a Decision Transformer on top of GPT-2, the model generates sentences conditioned on a toxicity control value, defined as the reward-to-go (RTG).
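For intuition, below is a minimal sketch of one common way such RTG conditioning can be wired into GPT-2: the per-token RTG value is projected into the embedding size and added to the token embeddings. The class name `RTGConditionedGPT2` and the projection layer are illustrative assumptions; the actual implementation lives in the linked repository and may differ.

```python
# Illustrative sketch only: the real model class is defined in the linked repository.
import torch.nn as nn
from transformers import GPT2LMHeadModel

class RTGConditionedGPT2(nn.Module):  # hypothetical name
    def __init__(self, model_name="openai-community/gpt2"):
        super().__init__()
        self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
        # Project the scalar per-token RTG into the hidden size
        self.rtg_proj = nn.Linear(1, self.gpt2.config.n_embd)

    def forward(self, input_ids, attention_mask, rtg, return_dict=True):
        tok_emb = self.gpt2.transformer.wte(input_ids)  # (batch, seq, hidden)
        rtg_emb = self.rtg_proj(rtg.unsqueeze(-1))      # (batch, seq, hidden)
        out = self.gpt2(inputs_embeds=tok_emb + rtg_emb,
                        attention_mask=attention_mask,
                        return_dict=True)
        return {"logits": out.logits}
```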
|
|
|
Current text generation is not very coherent due to limited variety in the training data and limited compute.
|
|
|
- **Developed by:** Ashed00
|
- **Finetuned from model:** GPT-2 (`openai-community/gpt2`)
|
|
|
### Model Sources
|
|
|
|
|
- **Repository:** https://github.com/Ashu-00/NLP-Implementations/tree/main/Decision_Transformer
|
- **Demo:** Coming soon
|
|
|
## Uses |
|
|
|
A fun little experiment in reward-conditioned (toxicity-controlled) text generation.
|
|
|
|
|
## Bias, Risks, and Limitations |
|
|
|
This model inherits the biases of its training data. I take no responsibility for what it generates.
|
|
|
Most generated text is incoherent due to the limited variety of the training data.
|
|
|
## How to Get Started with the Model |
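The snippet below assumes that `model`, `tokenizer`, `prompt`, and `device` are already defined. Here is a minimal setup sketch; the tokenizer and device lines are standard, but the RTG-conditioned model class itself is defined in the linked repository, so the loading line is only a hypothetical placeholder:

```python
# Minimal setup sketch; the model-loading line is a placeholder, not the repository's exact API.
import torch
from transformers import GPT2Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
prompt = "I think that"  # any seed text

# The RTG-conditioned model class lives in the linked repository; load its
# checkpoint there and move it to `device`, for example:
# model = RTGConditionedGPT2(...).to(device)  # hypothetical name
```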
|
|
|
```python |
|
|
|
import torch
import torch.nn.functional as F

# `model`, `tokenizer`, `prompt`, and `device` are assumed to be defined
# (see the setup sketch above and the linked repository for the model class).

def generate_conditioned_text2(model, tokenizer, prompt, target_rtg, max_length=50, temperature=1.0, top_k=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Create an RTG tensor holding the target value for each token in the prompt
    rtg = torch.tensor([[target_rtg] * input_ids.shape[1]], dtype=torch.float).to(device)

    seq_length = input_ids.shape[1]
    for _ in range(max_length):
        with torch.no_grad():
            # Slice rtg to match the current sequence length
            rtg_current = rtg[:, :seq_length]
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                rtg=rtg_current,
                return_dict=True
            )

        # Get the next-token logits and apply temperature scaling
        next_token_logits = outputs["logits"][:, -1, :] / temperature

        # Apply top-k filtering and sample the next token
        top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
        probabilities = F.softmax(top_k_logits, dim=-1)
        next_token = top_k_indices[0, torch.multinomial(probabilities, num_samples=1)]

        # Append the predicted token to input_ids and update the attention mask
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

        # Append the target reward for the newly generated token
        new_rtg = torch.tensor([[target_rtg]], dtype=torch.float).to(device)
        rtg = torch.cat([rtg, new_rtg], dim=1)

        # Stop if the EOS token is generated
        if next_token.item() == tokenizer.eos_token_id:
            break

        seq_length += 1

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
|
|
# Generate at both ends and the midpoint of the toxicity control range
less_toxic_text = generate_conditioned_text2(model, tokenizer, prompt, target_rtg=1.0)
more_toxic_text = generate_conditioned_text2(model, tokenizer, prompt, target_rtg=0.0)
avg_toxic = generate_conditioned_text2(model, tokenizer, prompt, target_rtg=0.5)

print("Less Toxic Text:", less_toxic_text)
print("More Toxic Text:", more_toxic_text)
print("Avg Toxic Text:", avg_toxic)
|
|
|
``` |
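To get a feel for the control, you can also sweep `target_rtg` across its range and compare the outputs; this simply reuses the function defined above.

```python
# Sweep the toxicity control from 0.0 to 1.0 and compare the generations.
for target_rtg in [0.0, 0.25, 0.5, 0.75, 1.0]:
    text = generate_conditioned_text2(model, tokenizer, prompt, target_rtg=target_rtg)
    print(f"target_rtg={target_rtg:.2f}: {text}")
```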
|
|
|
## Training Details |
|
|
|
Refer to the GitHub repository linked above for the training datasets and procedure.