gabrielclark3330 committed
Commit 7eeefc1 · 1 Parent(s): e9efc05

Add cuda and sampling params

Files changed (1):
  1. app.py +28 -17
app.py CHANGED

```diff
@@ -4,40 +4,51 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from huggingface_hub import login
 
-login(token=os.getenv('HF_TOKEN'))
-
 # Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
 model = AutoModelForCausalLM.from_pretrained(
     "Zyphra/Zamba2-7B",
-    device_map="auto",  # Automatically handles device placement
+    device_map="cuda",  # Place the model on the CUDA device
     torch_dtype=torch.bfloat16
 )
 
-def generate_response(input_text):
-    input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
+# Define the function to generate responses
+def generate_response(input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
+    # Tokenize and move the input to the model's device
+    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
+
+    # Generate a response using the specified parameters
     outputs = model.generate(
-        **input_ids,
-        max_new_tokens=500,
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
         do_sample=True,
-        temperature=0.7,
-        top_k=50,
-        top_p=0.9,
-        repetition_penalty=1.2,
-        num_beams=5,
-        length_penalty=1.0,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        num_beams=num_beams,
+        length_penalty=length_penalty,
         num_return_sequences=1
     )
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
 
-# Create the Gradio interface
+# Create the Gradio interface with adjustable parameters
 demo = gr.Interface(
     fn=generate_response,
-    inputs=gr.Textbox(lines=5, placeholder="Enter your question here..."),
-    outputs=gr.Textbox(),
+    inputs=[
+        gr.Textbox(lines=1, placeholder="Enter a text to prepend...", label="Input Text"),
+        gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens"),
+        gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature"),
+        gr.Slider(1, 100, step=1, value=50, label="Top K"),
+        gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P"),
+        gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty"),
+        gr.Slider(1, 10, step=1, value=5, label="Number of Beams"),
+        gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
+    ],
+    outputs=gr.Textbox(label="Generated Response"),
     title="Zamba2-7B Model",
-    description="Ask Zamba2 7B a question."
+    description="Ask Zamba2 7B a question with customizable parameters."
 )
 
 if __name__ == "__main__":
```
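The hunk stops at the `__main__` guard, so the launch call is not shown. For a quick sanity check of the new `generate_response` signature outside the Gradio UI, a minimal sketch might look like the following; the prompt string and the direct call are illustrative (not part of the commit), and the values mirror the slider defaults above:

```python
# Sketch (not part of the commit): exercise the updated generate_response
# directly, passing the same defaults the sliders expose.
print(generate_response(
    "Write a haiku about GPUs.",  # illustrative prompt, not from the commit
    max_new_tokens=500,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,
    num_beams=5,       # beam search combined with sampling ("beam-sample" mode)
    length_penalty=1.0,
))
```

Note that `num_beams=5` together with `do_sample=True` puts `model.generate` into beam-sample mode, and `length_penalty` only takes effect when beam search is active.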