AdoubleLen commited on
Commit
bdbb322
·
verified ·
1 Parent(s): 19fc4ad

Upload model

Browse files
Files changed (3) hide show
  1. README.md +193 -0
  2. adapter_config.json +20 -0
  3. adapter_model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: peft
3
+ ---
4
+ <div style="text-align: center">
5
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png">
6
+ </div>
7
+
8
+ # TRL - Transformer Reinforcement Learning
9
+ > Full stack transformer language models with reinforcement learning.
10
+
11
+ <p align="center">
12
+ <a href="https://github.com/huggingface/trl/blob/main/LICENSE">
13
+ <img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue">
14
+ </a>
15
+ <a href="https://huggingface.co/docs/trl/index">
16
+ <img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_message=online">
17
+ </a>
18
+ <a href="https://github.com/huggingface/trl/releases">
19
+ <img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg">
20
+ </a>
21
+ </p>
22
+
23
+
24
+ ## What is it?
25
+
26
+ <div style="text-align: center">
27
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
28
+ </div>
29
+
30
+ `trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most of decoder architectures and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
31
+
32
+ **Highlights:**
33
+
34
+ - [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
35
+ - [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
36
+ - [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
37
+ - [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
38
+ - [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
39
+
40
+ ## How PPO works
41
+ Fine-tuning a language model via PPO consists of roughly three steps:
42
+
43
+ 1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
44
+ 2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
45
+ 3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
46
+
47
+ This process is illustrated in the sketch below:
48
+
49
+
50
+ <div style="text-align: center">
51
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
52
+ <p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
53
+ </div>
54
+
55
+ ## Installation
56
+
57
+ ### Python package
58
+ Install the library with pip:
59
+ ```bash
60
+ pip install trl
61
+ ```
62
+
63
+ ### From source
64
+ If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
65
+ ```bash
66
+ git clone https://github.com/huggingface/trl.git
67
+ cd trl/
68
+ pip install .
69
+ ```
70
+
71
+ If you wish to develop TRL, you should install in editable mode:
72
+ ```bash
73
+ pip install -e .
74
+ ```
75
+
76
+ ## How to use
77
+
78
+ ### `SFTTrainer`
79
+
80
+ This is a basic example on how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
81
+
82
+ ```python
83
+ # imports
84
+ from datasets import load_dataset
85
+ from trl import SFTTrainer
86
+
87
+ # get dataset
88
+ dataset = load_dataset("imdb", split="train")
89
+
90
+ # get trainer
91
+ trainer = SFTTrainer(
92
+ "facebook/opt-350m",
93
+ train_dataset=dataset,
94
+ dataset_text_field="text",
95
+ max_seq_length=512,
96
+ )
97
+
98
+ # train
99
+ trainer.train()
100
+ ```
101
+
102
+ ### `RewardTrainer`
103
+
104
+ This is a basic example on how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
105
+
106
+ ```python
107
+ # imports
108
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
109
+ from trl import RewardTrainer
110
+
111
+ # load model and dataset - dataset needs to be in a specific format
112
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
113
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
114
+
115
+ ...
116
+
117
+ # load trainer
118
+ trainer = RewardTrainer(
119
+ model=model,
120
+ tokenizer=tokenizer,
121
+ train_dataset=dataset,
122
+ )
123
+
124
+ # train
125
+ trainer.train()
126
+ ```
127
+
128
+ ### `PPOTrainer`
129
+
130
+ This is a basic example on how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
131
+
132
+ ```python
133
+ # imports
134
+ import torch
135
+ from transformers import AutoTokenizer
136
+ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
137
+ from trl.core import respond_to_batch
138
+
139
+ # get models
140
+ model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
141
+ model_ref = create_reference_model(model)
142
+
143
+ tokenizer = AutoTokenizer.from_pretrained('gpt2')
144
+
145
+ # initialize trainer
146
+ ppo_config = PPOConfig(
147
+ batch_size=1,
148
+ )
149
+
150
+ # encode a query
151
+ query_txt = "This morning I went to the "
152
+ query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
153
+
154
+ # get model response
155
+ response_tensor = respond_to_batch(model, query_tensor)
156
+
157
+ # create a ppo trainer
158
+ ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
159
+
160
+ # define a reward for response
161
+ # (this could be any reward such as human feedback or output from another model)
162
+ reward = [torch.tensor(1.0)]
163
+
164
+ # train model for one step with ppo
165
+ train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
166
+ ```
167
+
168
+ ## References
169
+
170
+ ### Proximal Policy Optimisation
171
+ The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
172
+
173
+ ### Language models
174
+ The language models utilize the `transformers` library by 🤗 Hugging Face.
175
+
176
+ ## Citation
177
+
178
+ ```bibtex
179
+ @misc{vonwerra2022trl,
180
+ author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
181
+ title = {TRL: Transformer Reinforcement Learning},
182
+ year = {2020},
183
+ publisher = {GitHub},
184
+ journal = {GitHub repository},
185
+ howpublished = {\url{https://github.com/huggingface/trl}}
186
+ }
187
+ ```
188
+ ## Training procedure
189
+
190
+ ### Framework versions
191
+
192
+
193
+ - PEFT 0.5.0
adapter_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_mapping": null,
3
+ "base_model_name_or_path": "shibing624/gpt2-dialogbot-base-chinese",
4
+ "bias": "none",
5
+ "fan_in_fan_out": true,
6
+ "inference_mode": true,
7
+ "init_lora_weights": true,
8
+ "layers_pattern": null,
9
+ "layers_to_transform": null,
10
+ "lora_alpha": 32,
11
+ "lora_dropout": 0.1,
12
+ "modules_to_save": null,
13
+ "peft_type": "LORA",
14
+ "r": 16,
15
+ "revision": null,
16
+ "target_modules": [
17
+ "c_attn"
18
+ ],
19
+ "task_type": "CAUSAL_LM"
20
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c55a49dc756d0c2c341d4ba8ac47a7fa8f95339df5fa1a076b5e9098635db4a
3
+ size 1968648