|
---
license: mit
datasets:
- Kwaai/toxic_classification
tags:
- PPO
- RLHF
pipeline_tag: text-generation
---
|
This model was aligned with Proximal Policy Optimization (PPO) so that it generates non-toxic reviews. Training uses the `trl` library for reinforcement learning, `transformers` for model handling, and `datasets` for dataset management.
|
Implementation code is available here: [GitHub](https://github.com/Kwaai-AI-Lab/kwaai-alignment/tree/main/Implementations/GPT2_NonToxic) |
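
For orientation, the sketch below outlines what a PPO alignment loop with `trl` typically looks like. It is a minimal sketch, not the repository's training script: it assumes the classic `PPOTrainer` API (trl ≤ 0.11), a `text` prompt column in the dataset, and a toxicity classifier (here `unitary/toxic-bert`) as the reward signal; see the linked GitHub repository for the actual implementation.

```python
# Illustrative sketch of a PPO alignment loop with the classic trl PPOTrainer API
# (trl <= 0.11). Column names, hyperparameters, and the reward classifier are
# assumptions, not the exact training recipe used for this model.
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, pipeline
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token

# Policy to optimize and a frozen reference copy used for the KL penalty
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

# Prompts to generate from; the "text" column name is an assumption
dataset = load_dataset("Kwaai/toxic_classification", split="train")
dataset = dataset.map(lambda x: {"input_ids": tokenizer.encode(x["text"])[:32]})
dataset.set_format(type="torch")

def collator(data):
    # Keep variable-length prompts as lists of tensors instead of padding them
    return {key: [d[key] for d in data] for key in data[0]}

config = PPOConfig(model_name=model_name, learning_rate=1.41e-5,
                   batch_size=16, mini_batch_size=4)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer,
                         dataset=dataset, data_collator=collator)

# Reward signal: any toxicity classifier works; "unitary/toxic-bert" is one example
toxicity_clf = pipeline("text-classification", model="unitary/toxic-bert")

def toxicity_reward(texts):
    # Rough proxy: reward = 1 - toxicity score, so less toxic text earns more
    return [1.0 - out["score"] for out in toxicity_clf(texts, truncation=True)]

generation_kwargs = {"max_new_tokens": 48, "do_sample": True,
                     "pad_token_id": tokenizer.eos_token_id}

for batch in ppo_trainer.dataloader:
    query_tensors = batch["input_ids"]
    # Sample continuations from the current policy
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False,
                                            **generation_kwargs)
    responses = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
    # Score the generations and run one PPO optimization step
    rewards = [torch.tensor(r) for r in toxicity_reward(responses)]
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
```

The aligned model published in this repository can then be loaded for inference: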
|
```python
# Load the aligned model and tokenizer from the Hugging Face Hub
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Kwaai/GPT2_NonToxic")
model = AutoModelForCausalLM.from_pretrained("Kwaai/GPT2_NonToxic")

# Example usage
input_text = "you are toxic!"
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=50,                    # cap the length of the continuation
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```