Gopal2002 committed on
Commit
6ea7bd8
·
verified ·
1 Parent(s): 9b55569

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import torch
4
+
5
+ import transformers
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
7
+ from datasets import load_dataset
8
+ from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
9
+ from trl import DPOTrainer
10
+ import bitsandbytes as bnb
11
+ from google.colab import userdata
12
+ import wandb
13
+
14
+ # Defined in the secrets tab in Google Colab
15
+ # wb_token = "<REDACTED>"  # API key removed: never commit secrets; revoke the exposed key
16
# Credentials are taken from the process environment; each is None when unset.
hf_token = os.environ.get("hf_token")
wb_token = os.environ.get("wb_token")

# Log in to Weights & Biases once, at import time, with the key read above.
wandb.login(key=wb_token)
19
+
20
+
21
+
22
+ # Fine-tune model with DPO
23
+
24
+
25
+ import gradio as gr
26
+
27
+
28
def greet(traindata_, output_repo):
    """Fine-tune the Zephyr-7B-Gemma base model with DPO using LoRA adapters.

    Args:
        traindata_: Dataset identifier passed to ``load_dataset``; its
            ``train`` split is used. Presumably it provides the
            prompt/chosen/rejected columns DPOTrainer expects -- this is
            not verified here.
        output_repo: Path used as the trainer's ``output_dir`` and as the
            location where the trained adapter is saved.

    Returns:
        The string ``"Training Done"`` once training has finished.

    Raises:
        gr.Error: If either argument is empty or whitespace-only.
    """
    # Gradio textboxes yield "" when left blank; fail early with a readable
    # message instead of a confusing error from load_dataset/TrainingArguments.
    traindata_ = (traindata_ or "").strip()
    output_repo = (output_repo or "").strip()
    if not traindata_ or not output_repo:
        raise gr.Error(
            "Both the training data repo and the output model path are required."
        )

    model_name = "HuggingFaceH4/zephyr-7b-gemma-v0.1"
    new_model = output_repo

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Specify how to quantize the model (4-bit NF4, bfloat16 compute).
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # Pin the whole model to the current GPU when one is available.
    device_map = (
        {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
    )

    # Load the base model exactly once, in 4-bit. The previous version loaded
    # the same 7B checkpoint three times: one copy was immediately overwritten
    # and a separate reference copy was never handed to the trainer.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # attn_implementation="flash_attention_2" can be added on supported GPUs.
        torch_dtype="auto",
        use_cache=False,  # the KV cache is not used with gradient checkpointing
        device_map=device_map,
        quantization_config=quantization_config,
    )

    # LoRA adapter configuration
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "k_proj",
            "gate_proj",
            "v_proj",
            "up_proj",
            "q_proj",
            "o_proj",
            "down_proj",
        ],
    )

    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=200,
        save_strategy="no",
        logging_steps=1,
        output_dir=new_model,
        optim="paged_adamw_32bit",
        warmup_steps=100,
        bf16=True,
        report_to="wandb",
    )

    # Load the dataset
    dataset = load_dataset(traindata_, split="train")

    # Create the DPO trainer. ref_model stays None: no separate reference model
    # is supplied, and a PEFT config is passed instead (see the TRL DPOTrainer
    # documentation for how the reference is derived in that case).
    dpo_trainer = DPOTrainer(
        model,
        ref_model=None,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        beta=0.1,
        # The prompt budget must fit inside max_length; the former value of
        # 2048 exceeded the 1536-token total. 1024 leaves 512 for the response.
        max_prompt_length=1024,
        max_length=1536,
    )
    dpo_trainer.train()

    # save_strategy="no" disables checkpoints, so persist the result explicitly;
    # otherwise the trained adapter is lost when this function returns.
    dpo_trainer.save_model(new_model)

    return "Training Done"
115
+
116
+
117
# Gradio front end: two text inputs, one output box and a button wired to greet().
with gr.Blocks() as demo:
    traindata_ = gr.Textbox(label="Enter training data repo")
    output_repo = gr.Textbox(label="Enter output model path")

    output = gr.Textbox(label="Output Box")
    greet_btn = gr.Button("TRAIN")
    # On click, pass both textbox values to greet() and show its return value.
    greet_btn.click(
        fn=greet,
        inputs=[traindata_, output_repo],
        outputs=output,
        api_name="greet",
    )

demo.launch()