File size: 7,037 Bytes
ed0dca2
a40632d
bdf3e70
38d7f73
 
2f1a468
 
bdf3e70
30da7cc
 
9e53c43
b1cf10f
30da7cc
b1cf10f
a40632d
 
42cd34a
30da7cc
42cd34a
30da7cc
 
 
 
2f1a468
30da7cc
 
 
 
 
 
 
ab084df
 
 
 
 
2f1a468
 
ab084df
 
 
 
2f1a468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30da7cc
2f1a468
 
 
 
 
 
 
 
42cd34a
bdf3e70
2f1a468
 
 
 
 
 
 
 
 
ed0dca2
52589e7
17dfda2
 
c59143e
17dfda2
 
 
ab084df
f830874
 
17dfda2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab084df
 
17dfda2
 
079475b
 
ab084df
17dfda2
 
 
 
 
 
ab084df
 
 
 
 
 
 
 
 
 
761e8ba
ab084df
f830874
bc14b30
 
 
 
 
 
 
 
 
 
 
 
 
 
17dfda2
f58a3a5
717f53b
bc14b30
611772e
e751c8d
c5cebfc
717f53b
e751c8d
bc14b30
17dfda2
 
 
ab084df
17dfda2
f830874
8c9400b
 
2f1a468
 
 
 
 
ab084df
2f1a468
8d7ed18
 
717f53b
2f1a468
 
bc14b30
 
 
 
 
 
 
 
 
8d7ed18
17dfda2
2f1a468
 
 
 
 
17dfda2
 
52589e7
 
 
 
17dfda2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
##################################### Imports ######################################
# Generic imports
import gradio as gr
import json

# Specialized imports
#from utilities.modeling import modeling

# Module imports
from utilities.setup import get_json_cfg
from utilities.templates import prompt_template

########################### Global objects and functions ###########################

# Load the application configuration once at import time; main() reads the
# UI defaults (model choices plus the general/peft/sft settings) from it.
conf = get_json_cfg()

def textbox_visibility(radio):
    """Toggle the Hugging Face Hub dataset textbox from the radio selection.

    Args:
        radio: Currently selected value of the dataset-source radio group.

    Returns:
        A ``gr.Textbox`` update that is visible only when ``radio`` equals
        "Hugging Face Hub Dataset".
    """
    # The component this callback updates is a Textbox, so the update is
    # expressed as a Textbox (not a Dropdown) to match its actual type.
    return gr.Textbox(visible=(radio == "Hugging Face Hub Dataset"))


def upload_visibility(radio):
    """Show the dataset upload button only for the "Upload Your Own" choice.

    Args:
        radio: Currently selected value of the dataset-source radio group.

    Returns:
        A ``gr.UploadButton`` update whose ``visible`` flag is True when
        ``radio`` equals "Upload Your Own" and False otherwise.
    """
    show_upload = radio == "Upload Your Own"
    return gr.UploadButton(visible=show_upload)

#from datasets import load_dataset

def get_predefined_dataset(dataset_name):
    """Load the "train" split of the named dataset via ``load_dataset``.

    Args:
        dataset_name: Dataset identifier forwarded to ``load_dataset``.

    Returns:
        Whatever ``load_dataset(dataset_name, split="train")`` returns.
    """
    # NOTE(review): `load_dataset` is not imported in the visible part of this
    # file (the `from datasets import load_dataset` line is commented out), so
    # this call raises NameError unless the name is provided elsewhere.
    return load_dataset(dataset_name, split="train")


def get_uploaded_dataset(file=None):
    """Return a short preview (the first 100 characters) of an uploaded file.

    Args:
        file: The uploaded file, given either as a filesystem path or as an
            object exposing the path through a ``name`` attribute. ``None``
            means nothing has been uploaded yet.

    Returns:
        The first 100 characters of the file, or an empty string when
        ``file`` is ``None``.

    Raises:
        OSError: If the file cannot be opened.
    """
    if file is None:
        return ""

    # Accept both a plain path and a file-like wrapper carrying `.name`.
    path = getattr(file, "name", file)

    # Only a preview is needed, so read just the leading characters instead of
    # the whole file; undecodable bytes are replaced rather than failing.
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        return f.read(100)



def train(model_name,
          inject_prompt,
          dataset_predefined,
          peft,
          sft,
          max_seq_length,
          random_seed,
          num_epochs,
          max_steps,
          data_field,
          repository,
          model_out_name):
    """Placeholder for the fine-tuning call.

    All arguments are accepted so the signature matches the UI inputs, but at
    present only ``model_name`` and ``inject_prompt`` are used.

    Returns:
        A status string naming the selected model and prompt template.
    """
    # TODO: hook up the real training run once `utilities.modeling` is enabled
    # (build a trainer from the selected model/parameters, call
    # `trainer.train()` and report its statistics instead of this message).
    status_message = f"Hello!! Using model: {model_name} with template: {inject_prompt}"
    return status_message


def submit_weights(model, repository, model_out_name, token, tokenizer=None):
    """Push a model (and optionally its tokenizer) to a hub repository.

    Args:
        model: Object exposing ``push_to_hub(repo, token=...)``.
        repository: First part of the target repository id.
        model_out_name: Second part of the target repository id.
        token: Access token forwarded to every ``push_to_hub`` call.
        tokenizer: Optional object exposing ``push_to_hub(repo, token=...)``.
            It used to be read from an undefined global, which raised
            NameError; it is now an explicit argument and is skipped when
            not supplied.

    Returns:
        0 after the uploads have been issued.
    """
    repo = f"{repository}/{model_out_name}"

    model.push_to_hub(repo, token=token)
    if tokenizer is not None:
        tokenizer.push_to_hub(repo, token=token)
    return 0

##################################### App UI #######################################

def main():
    """Build the Gradio fine-tuning UI and launch it.

    Lays out the model, dataset and parameter inputs, wires the dataset
    visibility toggles and preview callbacks, connects the "Start Fine Tuning"
    button to ``train`` and finally starts the app with ``demo.launch()``.
    """
    # Defaults for most of the input widgets come from this config section.
    general_cfg = conf['model']['general']

    with gr.Blocks() as demo:

        ##### Title Block #####
        gr.Markdown("# Instruction Tuning with Unsloth")

        ##### Initial Model Inputs #####
        gr.Markdown("### Model Inputs")

        # Select Model
        modelnames = conf['model']['choices']
        model_name = gr.Dropdown(label="Supported Models",
                                 choices=modelnames,
                                 value=modelnames[0])
        # Prompt template
        inject_prompt = gr.Textbox(label="Prompt Template",
                                   value=prompt_template())
        # Dataset choice
        dataset_choice = gr.Radio(label="Choose Dataset",
                                  choices=["Hugging Face Hub Dataset", "Upload Your Own"],
                                  value="Hugging Face Hub Dataset")
        dataset_predefined = gr.Textbox(label="Hugging Face Hub Dataset",
                                        value='yahma/alpaca-cleaned',
                                        visible=True)
        dataset_predefined_load = gr.Button("Upload Dataset")
        dataset_uploaded_load = gr.UploadButton(label="Upload Dataset (csv, jsonl, or txt)",
                                                file_types=[".csv", ".jsonl", ".txt"],
                                                visible=False)
        data_field = gr.Textbox(label="Dataset Training Field",
                                value=general_cfg["dataset_text_field"])
        data_snippet = gr.Markdown()

        # Show either the Hub dataset textbox or the upload button, depending
        # on the selected dataset source.
        dataset_choice.change(textbox_visibility,
                              dataset_choice,
                              dataset_predefined)
        dataset_choice.change(upload_visibility,
                              dataset_choice,
                              dataset_uploaded_load)

        # Dataset preview: the Hub loader receives the dataset name typed in
        # the textbox (not the button itself).
        dataset_predefined_load.click(fn=get_predefined_dataset,
                                      inputs=dataset_predefined,
                                      outputs=data_snippet)

        # Use the upload event so the callback runs once a file has actually
        # been provided, rather than when the button is merely clicked.
        dataset_uploaded_load.upload(fn=get_uploaded_dataset,
                                     inputs=dataset_uploaded_load,
                                     outputs=data_snippet)

        ##### Model Parameter Inputs #####
        gr.Markdown("### Model Parameter Selection")
        # Parameters
        max_seq_length = gr.Textbox(label="Maximum sequence length",
                                    value=general_cfg["max_seq_length"])
        random_seed = gr.Textbox(label="Seed",
                                 value=general_cfg["seed"])
        num_epochs = gr.Textbox(label="Training Epochs",
                                value=general_cfg["num_train_epochs"])
        max_steps = gr.Textbox(label="Maximum steps",
                               value=general_cfg["max_steps"])
        repository = gr.Textbox(label="Repository Name",
                                value=general_cfg["repository"])
        model_out_name = gr.Textbox(label="Model Output Name",
                                    value=general_cfg["model_name"])

        # Hyperparameters (allow selection, but hide in accordion.)
        with gr.Accordion("Advanced Tuning", open=False):
            # Expose the PEFT / SFT settings as editable JSON text.
            dict_string = json.dumps(dict(conf['model']['peft']), indent=4)
            peft = gr.Textbox(label="PEFT Parameters (json)", value=dict_string)

            dict_string = json.dumps(dict(conf['model']['sft']), indent=4)
            sft = gr.Textbox(label="SFT Parameters (json)", value=dict_string)

        ##### Execution #####

        # Setup buttons
        tune_btn = gr.Button("Start Fine Tuning")
        gr.Markdown("### Model Progress")
        # Text output (for now)
        output = gr.Textbox(label="Output")

        # Execute buttons. The inputs are listed in exactly the order of
        # train()'s positional parameters, one component per parameter.
        tune_btn.click(fn=train,
                       inputs=[model_name,
                               inject_prompt,
                               dataset_predefined,
                               peft,
                               sft,
                               max_seq_length,
                               random_seed,
                               num_epochs,
                               max_steps,
                               data_field,
                               repository,
                               model_out_name,
                               ],
                       outputs=output)
        # stop button

        # submit button

    # Launch once the layout is fully defined and the Blocks context is closed.
    demo.launch()

##################################### Launch #######################################

# Start the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()