SpicyMelonYT commited on
Commit
f934c1a
·
1 Parent(s): 9ad7da7

another app fix

Browse files
Files changed (1) hide show
  1. app.py +15 -13
app.py CHANGED
@@ -2,12 +2,11 @@ import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
4
  from datasets import load_dataset
5
- import os
6
 
7
  """
8
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
9
  """
10
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
11
 
12
 
13
  def respond(
@@ -43,16 +42,16 @@ def respond(
43
  yield response
44
 
45
 
46
- def train_model(hf_token):
47
- # Set the Hugging Face token as an environment variable
48
  os.environ["HUGGINGFACE_TOKEN"] = hf_token
49
 
50
  # Load dataset
51
- dataset = load_dataset('json', data_files='dataset.jsonl')
 
52
 
53
  # Load model
54
- model = AutoModelForCausalLM.from_pretrained(
55
- 'meta-llama/Meta-Llama-3-8B-Instruct', use_auth_token=hf_token)
56
 
57
  # Define training arguments
58
  training_args = TrainingArguments(
@@ -68,8 +67,7 @@ def train_model(hf_token):
68
  model=model,
69
  args=training_args,
70
  train_dataset=dataset['train'],
71
- # Using train as eval for this simple example
72
- eval_dataset=dataset['train']
73
  )
74
 
75
  # Start training
@@ -94,8 +92,13 @@ with demo:
94
  step=1, label="Max new tokens"),
95
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7,
96
  step=0.1, label="Temperature"),
97
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95,
98
- step=0.05, label="Top-p (nucleus sampling)"),
 
 
 
 
 
99
  ],
100
  )
101
  with gr.Tab("Train"):
@@ -103,8 +106,7 @@ with demo:
103
  train_button = gr.Button("Start Training")
104
  train_output = gr.Textbox(label="Training Output")
105
 
106
- train_button.click(fn=train_model, inputs=hf_token,
107
- outputs=train_output)
108
 
109
  if __name__ == "__main__":
110
  demo.launch()
 
2
  from huggingface_hub import InferenceClient
3
  from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
4
  from datasets import load_dataset
 
5
 
6
  """
7
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
8
  """
9
+ client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
10
 
11
 
12
  def respond(
 
42
  yield response
43
 
44
 
45
+ def train_model():
46
+
47
  os.environ["HUGGINGFACE_TOKEN"] = hf_token
48
 
49
  # Load dataset
50
+ dataset = load_dataset('json', data_files={
51
+ 'train': 'training_set.json'})
52
 
53
  # Load model
54
+ model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')
 
55
 
56
  # Define training arguments
57
  training_args = TrainingArguments(
 
67
  model=model,
68
  args=training_args,
69
  train_dataset=dataset['train'],
70
+ eval_dataset=dataset['test']
 
71
  )
72
 
73
  # Start training
 
92
  step=1, label="Max new tokens"),
93
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7,
94
  step=0.1, label="Temperature"),
95
+ gr.Slider(
96
+ minimum=0.1,
97
+ maximum=1.0,
98
+ value=0.95,
99
+ step=0.05,
100
+ label="Top-p (nucleus sampling)",
101
+ ),
102
  ],
103
  )
104
  with gr.Tab("Train"):
 
106
  train_button = gr.Button("Start Training")
107
  train_output = gr.Textbox(label="Training Output")
108
 
109
+ train_button.click(train_model, outputs=train_output)
 
110
 
111
  if __name__ == "__main__":
112
  demo.launch()