SpicyMelonYT commited on
Commit
8245e16
·
1 Parent(s): ea4b878

Add training functionality to Gradio app

Browse files
Files changed (2) hide show
  1. app.py +55 -17
  2. workspace.code-workspace +8 -0
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
 
 
3
 
4
  """
5
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
  """
7
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
9
-
10
  def respond(
11
  message,
12
  history: list[tuple[str, str]],
@@ -39,25 +40,62 @@ def respond(
39
  response += token
40
  yield response
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  """
43
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
  """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
 
 
 
 
 
 
 
 
60
 
 
61
 
62
  if __name__ == "__main__":
63
- demo.launch()
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
+ from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
4
+ from datasets import load_dataset
5
 
6
  """
7
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
8
  """
9
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
10
 
 
11
  def respond(
12
  message,
13
  history: list[tuple[str, str]],
 
40
  response += token
41
  yield response
42
 
43
+ def train_model():
44
+ # Load dataset
45
+ dataset = load_dataset('your_dataset_name')
46
+
47
+ # Load model
48
+ model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B')
49
+
50
+ # Define training arguments
51
+ training_args = TrainingArguments(
52
+ output_dir='./results',
53
+ num_train_epochs=3,
54
+ per_device_train_batch_size=16,
55
+ save_steps=10_000,
56
+ save_total_limit=2,
57
+ )
58
+
59
+ # Initialize Trainer
60
+ trainer = Trainer(
61
+ model=model,
62
+ args=training_args,
63
+ train_dataset=dataset['train'],
64
+ eval_dataset=dataset['test']
65
+ )
66
+
67
+ # Start training
68
+ trainer.train()
69
+ return "Training complete"
70
+
71
  """
72
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
73
  """
74
+ demo = gr.Blocks()
75
+
76
+ with demo:
77
+ gr.Markdown("# Llama3training Chatbot and Model Trainer")
78
+ with gr.Tab("Chat"):
79
+ gr.ChatInterface(
80
+ respond,
81
+ additional_inputs=[
82
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
83
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
84
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
85
+ gr.Slider(
86
+ minimum=0.1,
87
+ maximum=1.0,
88
+ value=0.95,
89
+ step=0.05,
90
+ label="Top-p (nucleus sampling)",
91
+ ),
92
+ ],
93
+ )
94
+ with gr.Tab("Train"):
95
+ train_button = gr.Button("Start Training")
96
+ train_output = gr.Textbox(label="Training Output")
97
 
98
+ train_button.click(train_model, outputs=train_output)
99
 
100
  if __name__ == "__main__":
101
+ demo.launch()
workspace.code-workspace ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "folders": [
3
+ {
4
+ "path": "."
5
+ }
6
+ ],
7
+ "settings": {}
8
+ }