tykiww's picture
Update app.py
36a6459 verified
raw
history blame
7.38 kB
##################################### Imports ######################################
# Generic imports
import gradio as gr
import json
# Specialized imports
#from utilities.modeling import modeling
# Module imports
from utilities.setup import get_json_cfg
from utilities.templates import prompt_template
########################### Global objects and functions ###########################
# Application configuration, loaded once at import time via
# ``utilities.setup.get_json_cfg``. ``main`` reads ``conf['model']['choices']``
# (model dropdown), ``conf['model']['general']`` (default form values),
# ``conf['model']['peft']`` and ``conf['model']['sft']`` (advanced JSON params).
conf = get_json_cfg()
def textbox_visibility(radio):
    """Toggle the Hub-dataset textbox based on the dataset-source radio.

    Args:
        radio: The currently selected choice of the "Choose Dataset" radio.

    Returns:
        A ``gr.Textbox`` update whose only property is ``visible``; it is
        ``True`` exactly when *radio* is ``"Hugging Face Hub Dataset"``.
    """
    # The output wired to this handler in ``main`` (``dataset_predefined``)
    # is a Textbox, so return a Textbox update rather than a Dropdown.
    return gr.Textbox(visible=radio == "Hugging Face Hub Dataset")
def textbox_button_visibility(radio):
    """Show the Hub "Upload Dataset" button only for the Hub-dataset source.

    Args:
        radio: Selected value of the dataset-source radio.

    Returns:
        A ``gr.Button`` update with ``visible`` set from the selection.
    """
    show_button = radio == "Hugging Face Hub Dataset"
    return gr.Button(visible=show_button)
def upload_visibility(radio):
    """Show the file-upload button only when "Upload Your Own" is selected.

    Args:
        radio: Selected value of the dataset-source radio.

    Returns:
        A ``gr.UploadButton`` update with ``visible`` set from the selection.
    """
    show_upload = radio == "Upload Your Own"
    return gr.UploadButton(visible=show_upload)
from datasets import load_dataset
def get_predefined_dataset(dataset_name):
    """Load the ``train`` split of a Hub dataset and return a short preview.

    Args:
        dataset_name: Hugging Face Hub dataset id (e.g. the UI default
            ``"yahma/alpaca-cleaned"``).

    Returns:
        The first 100 characters of the ``output`` field of the first record,
        for display in the ``data_snippet`` Markdown; a notice string when
        the split has no records.
    """
    dataset = load_dataset(dataset_name, split="train")
    # Guard the ``dataset[0]`` lookup below, which would raise IndexError.
    if len(dataset) == 0:
        return f"Dataset `{dataset_name}` has no training records."
    snippet = dataset[0]["output"][:100]
    # Keep the console echo, but return the snippet itself (not print()'s
    # None) so the Markdown preview is actually populated.
    print(snippet)
    return snippet
def get_uploaded_dataset(file):
    """Return a short text preview (the file name/path) of an uploaded file.

    Args:
        file: The value delivered by the upload button. This may be either a
            path string or a tempfile-like object exposing ``.name``,
            depending on the Gradio version (confirm against the installed
            version). ``None`` means no file has been uploaded yet.

    Returns:
        The first 100 characters of the file's name/path as a string, or a
        notice when no file is present. A string is returned, rather than
        ``0``, because the output component is a ``gr.Markdown``.
    """
    if file is None:
        return "No file uploaded."
    # Accept both shapes: objects with ``.name`` and plain path strings.
    name = getattr(file, "name", file)
    snippet = str(name)[:100]
    print(snippet)
    return snippet
def train(model_name,
          inject_prompt,
          dataset_predefined,
          peft,
          sft,
          max_seq_length,
          random_seed,
          num_epochs,
          max_steps,
          data_field,
          repository,
          model_out_name):
    """Entry point for fine-tuning (currently a placeholder).

    Every argument comes from the Gradio form built in ``main``. Only
    ``model_name`` and ``inject_prompt`` are used for now; the remaining
    parameters are accepted so the UI wiring is already complete.

    Returns:
        A status string echoing the chosen model and prompt template.
    """
    # TODO: construct a trainer via ``utilities.modeling.modeling`` (its
    # import is commented out at the top of the file), run it, and return
    # the training statistics instead of this placeholder message.
    status = f"Hello!! Using model: {model_name} with template: {inject_prompt}"
    return status
def submit_weights(model, repository, model_out_name, token, tokenizer=None):
    """Push trained weights (and optionally the tokenizer) to the Hub.

    Args:
        model: Object exposing ``push_to_hub(repo_id, token=...)``.
        repository: Hub namespace (user or organisation).
        model_out_name: Name of the target repository under *repository*.
        token: Hub access token forwarded to ``push_to_hub``.
        tokenizer: Optional object with the same ``push_to_hub`` method;
            pushed to the same repository when provided.

    Returns:
        ``0`` once the push calls have returned.
    """
    repo = f"{repository}/{model_out_name}"
    model.push_to_hub(repo, token=token)
    # ``tokenizer`` used to be an undefined global, which raised NameError
    # after the model had already been pushed. It is now an explicit
    # optional argument.
    if tokenizer is not None:
        tokenizer.push_to_hub(repo, token=token)
    return 0
##################################### App UI #######################################
def main():
    """Build the instruction-tuning Gradio UI and launch it.

    Choices and default values come from the module-level ``conf``. The
    form is wired to ``train`` and to the dataset preview and visibility
    helpers defined in this module.
    """
    with gr.Blocks() as demo:
        ##### Title Block #####
        gr.Markdown("# Instruction Tuning with Unsloth")

        ##### Initial Model Inputs #####
        gr.Markdown("### Model Inputs")
        # Select Model
        modelnames = conf['model']['choices']
        model_name = gr.Dropdown(label="Supported Models",
                                 choices=modelnames,
                                 value=modelnames[0])
        # Prompt template
        inject_prompt = gr.Textbox(label="Prompt Template",
                                   value=prompt_template())
        # Dataset choice
        dataset_choice = gr.Radio(label="Choose Dataset",
                                  choices=["Hugging Face Hub Dataset", "Upload Your Own"],
                                  value="Hugging Face Hub Dataset")
        dataset_predefined = gr.Textbox(label="Hugging Face Hub Training Dataset",
                                        value='yahma/alpaca-cleaned',
                                        visible=True)
        dataset_predefined_load = gr.Button("Upload Dataset")
        dataset_uploaded_load = gr.UploadButton(label="Upload Dataset (.csv, .jsonl, or .txt)",
                                                file_types=[".csv", ".jsonl", ".txt"],
                                                visible=False)
        # NOTE: hidden placeholder component; not wired to any event yet.
        file_output = gr.File(visible=False)
        data_field = gr.Textbox(label="Dataset Training Field",
                                value=conf['model']['general']["dataset_text_field"])
        data_snippet = gr.Markdown()

        # Show either the Hub-dataset controls or the upload button,
        # depending on the radio selection.
        dataset_choice.change(textbox_visibility,
                              dataset_choice,
                              dataset_predefined)
        dataset_choice.change(upload_visibility,
                              dataset_choice,
                              dataset_uploaded_load)
        dataset_choice.change(textbox_button_visibility,
                              dataset_choice,
                              dataset_predefined_load)

        # Dataset buttons
        dataset_predefined_load.click(fn=get_predefined_dataset,
                                      inputs=dataset_predefined,
                                      outputs=data_snippet)
        # Use ``upload`` rather than ``click``: ``click`` fires when the
        # button is pressed, before a file is chosen. ``upload`` fires once
        # the file is available, so the handler receives the new file.
        dataset_uploaded_load.upload(fn=get_uploaded_dataset,
                                     inputs=dataset_uploaded_load,
                                     outputs=data_snippet)

        ##### Model Parameter Inputs #####
        gr.Markdown("### Model Parameter Selection")
        # Parameters
        general = conf['model']['general']
        max_seq_length = gr.Textbox(label="Maximum sequence length",
                                    value=general["max_seq_length"])
        random_seed = gr.Textbox(label="Seed",
                                 value=general["seed"])
        num_epochs = gr.Textbox(label="Training Epochs",
                                value=general["num_train_epochs"])
        max_steps = gr.Textbox(label="Maximum steps",
                               value=general["max_steps"])
        repository = gr.Textbox(label="Repository Name",
                                value=general["repository"])
        model_out_name = gr.Textbox(label="Model Output Name",
                                    value=general["model_name"])

        # Hyperparameters (allow selection, but hide in accordion.)
        with gr.Accordion("Advanced Tuning", open=False):
            # PEFT / SFT settings are edited as pretty-printed JSON text.
            dict_string = json.dumps(dict(conf['model']['peft']), indent=4)
            peft = gr.Textbox(label="PEFT Parameters (json)", value=dict_string)
            dict_string = json.dumps(dict(conf['model']['sft']), indent=4)
            sft = gr.Textbox(label="SFT Parameters (json)", value=dict_string)

        ##### Execution #####
        # Setup buttons
        tune_btn = gr.Button("Start Fine Tuning")
        gr.Markdown("### Model Progress")
        # Text output (for now)
        output = gr.Textbox(label="Output")

        # Execute buttons: input order must match ``train``'s parameters.
        tune_btn.click(fn=train,
                       inputs=[model_name,
                               inject_prompt,
                               dataset_predefined,
                               peft,
                               sft,
                               max_seq_length,
                               random_seed,
                               num_epochs,
                               max_steps,
                               data_field,
                               repository,
                               model_out_name],
                       outputs=output)
        # TODO: stop button
        # TODO: submit button (wire to ``submit_weights``)

    # Launch baby
    demo.launch()
##################################### Launch #######################################
if __name__ == "__main__":
    # Start the UI only when executed as a script, not when imported.
    main()