SpicyMelonYT committed on
Commit
9ad7da7
·
1 Parent(s): 7cabf61

app code change for train token

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -2,11 +2,12 @@ import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
4
  from datasets import load_dataset
 
5
 
6
  """
7
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
8
  """
9
- client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
10
 
11
 
12
  def respond(
@@ -42,13 +43,16 @@ def respond(
42
  yield response
43
 
44
 
45
- def train_model():
 
 
 
46
  # Load dataset
47
- dataset = load_dataset('json', data_files={
48
- 'train': 'training_set.json'})
49
 
50
  # Load model
51
- model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')
 
52
 
53
  # Define training arguments
54
  training_args = TrainingArguments(
@@ -64,7 +68,8 @@ def train_model():
64
  model=model,
65
  args=training_args,
66
  train_dataset=dataset['train'],
67
- eval_dataset=dataset['test']
 
68
  )
69
 
70
  # Start training
@@ -89,20 +94,17 @@ with demo:
89
  step=1, label="Max new tokens"),
90
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7,
91
  step=0.1, label="Temperature"),
92
- gr.Slider(
93
- minimum=0.1,
94
- maximum=1.0,
95
- value=0.95,
96
- step=0.05,
97
- label="Top-p (nucleus sampling)",
98
- ),
99
  ],
100
  )
101
  with gr.Tab("Train"):
 
102
  train_button = gr.Button("Start Training")
103
  train_output = gr.Textbox(label="Training Output")
104
 
105
- train_button.click(train_model, outputs=train_output)
 
106
 
107
  if __name__ == "__main__":
108
  demo.launch()
 
2
  from huggingface_hub import InferenceClient
3
  from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
4
  from datasets import load_dataset
5
+ import os
6
 
7
  """
8
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
9
  """
10
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
11
 
12
 
13
  def respond(
 
43
  yield response
44
 
45
 
46
+ def train_model(hf_token):
47
+ # Set the Hugging Face token as an environment variable
48
+ os.environ["HUGGINGFACE_TOKEN"] = hf_token
49
+
50
  # Load dataset
51
+ dataset = load_dataset('json', data_files='dataset.jsonl')
 
52
 
53
  # Load model
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ 'meta-llama/Meta-Llama-3-8B-Instruct', use_auth_token=hf_token)
56
 
57
  # Define training arguments
58
  training_args = TrainingArguments(
 
68
  model=model,
69
  args=training_args,
70
  train_dataset=dataset['train'],
71
+ # Using train as eval for this simple example
72
+ eval_dataset=dataset['train']
73
  )
74
 
75
  # Start training
 
94
  step=1, label="Max new tokens"),
95
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7,
96
  step=0.1, label="Temperature"),
97
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95,
98
+ step=0.05, label="Top-p (nucleus sampling)"),
 
 
 
 
 
99
  ],
100
  )
101
  with gr.Tab("Train"):
102
+ hf_token = gr.Textbox(label="Hugging Face Token", type="password")
103
  train_button = gr.Button("Start Training")
104
  train_output = gr.Textbox(label="Training Output")
105
 
106
+ train_button.click(fn=train_model, inputs=hf_token,
107
+ outputs=train_output)
108
 
109
  if __name__ == "__main__":
110
  demo.launch()