Yahir committed
Commit 781b469 · verified
1 Parent(s): 17e8c28

Update app.py

Files changed (1)
  1. app.py +24 -61
app.py CHANGED
@@ -1,9 +1,8 @@
 from huggingface_hub import InferenceClient
 import gradio as gr
 
-client = InferenceClient(
-    "google/gemma-7b-it"
-)
+# Initialize the InferenceClient
+client = InferenceClient("google/gemma-7b-it")
 
 def format_prompt(message, history):
     prompt = ""
@@ -14,93 +13,57 @@ def format_prompt(message, history):
     prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
     return prompt
 
-def generate(
-    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
-):
+def generate_response(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
+    # Check if history is empty
     if not history:
         history = []
-    hist_len=0
-    if history:
-        hist_len=len(history)
-        print(hist_len)
-
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
+    # Ensure temperature is within a valid range
+    temperature = max(1e-2, float(temperature))
     top_p = float(top_p)
 
     generate_kwargs = dict(
         temperature=temperature,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=int(max_new_tokens),
         top_p=top_p,
-        repetition_penalty=repetition_penalty,
+        repetition_penalty=float(repetition_penalty),
         do_sample=True,
         seed=42,
     )
 
     formatted_prompt = format_prompt(prompt, history)
 
+    # Use the InferenceClient for text generation
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+
+    # Concatenate the generated responses
     output = ""
-
     for response in stream:
         output += response.token.text
         yield output
     return output
 
-
-additional_inputs=[
-    gr.Slider(
-        label="Temperature",
-        value=0.9,
-        minimum=0.0,
-        maximum=1.0,
-        step=0.05,
-        interactive=True,
-        info="Higher values produce more diverse outputs",
-    ),
-    gr.Slider(
-        label="Max new tokens",
-        value=512,
-        minimum=0,
-        maximum=1048,
-        step=64,
-        interactive=True,
-        info="The maximum numbers of new tokens",
-    ),
-    gr.Slider(
-        label="Top-p (nucleus sampling)",
-        value=0.90,
-        minimum=0.0,
-        maximum=1,
-        step=0.05,
-        interactive=True,
-        info="Higher values sample more low-probability tokens",
-    ),
-    gr.Slider(
-        label="Repetition penalty",
-        value=1.2,
-        minimum=1.0,
-        maximum=2.0,
-        step=0.05,
-        interactive=True,
-        info="Penalize repeated tokens",
-    )
+# Additional input sliders for responsiveness
+additional_inputs = [
+    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
+    gr.Slider(label="Max new tokens", value=512, minimum=0, maximum=1048, step=64, interactive=True, info="The maximum number of new tokens"),
+    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
+    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
 ]
 
 # Create a Chatbot object with the desired height
-chatbot = gr.Chatbot(height=450,
-                     layout="bubble",
-                     placeholder="Type here to chat...")
+chatbot = gr.Chatbot(height=450, layout="bubble")
 
 with gr.Blocks() as demo:
+    # Display a title
     gr.HTML("<h1><center>🤖 Google-Gemma-7B-Chat 💬</center></h1>")
+
+    # Use ChatInterface for user interaction
     gr.ChatInterface(
-        generate,
-        chatbot=chatbot, # Use the created Chatbot object
+        generate_response,
+        chatbot=chatbot,
         additional_inputs=additional_inputs,
         examples=[["What is the meaning of life?"], ["Tell me something about Mt Fuji."]],
-        placeholder="Type here to chat..."
     )
 
+# Launch the Gradio interface
 demo.queue().launch(debug=True)
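Note on the prompt format: Gemma's instruction-tuned checkpoints expect each turn wrapped in <start_of_turn>…<end_of_turn> markers. The history loop inside format_prompt is elided from this diff (only the function's first and last lines appear in the hunks), so the following is a hypothetical reconstruction, assuming history arrives as the list of (user, bot) pairs that gr.ChatInterface passes:

def format_prompt(message, history):
    prompt = ""
    # Hypothetical reconstruction of the elided loop: replay prior turns
    # in Gemma's chat template before appending the new user message.
    for user_turn, bot_turn in history:
        prompt += f"<start_of_turn>user{user_turn}<end_of_turn>"
        prompt += f"<start_of_turn>model{bot_turn}<end_of_turn>"
    prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
    return prompt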
 
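The streaming call itself is carried over unchanged: with stream=True and details=True, client.text_generation yields per-token objects whose token.text holds the newly generated fragment, and yielding the growing string from a generator is what lets gr.ChatInterface render the reply incrementally. A minimal standalone sketch of the same call outside Gradio, assuming access to the gated google/gemma-7b-it model and a valid HF token:

from huggingface_hub import InferenceClient

client = InferenceClient("google/gemma-7b-it")
prompt = "<start_of_turn>userTell me something about Mt Fuji.<end_of_turn><start_of_turn>model"

output = ""
# Each streamed item exposes token-level detail; return_full_text=False
# keeps the echoed prompt out of the result.
for chunk in client.text_generation(prompt, max_new_tokens=64, stream=True, details=True, return_full_text=False):
    output += chunk.token.text
print(output)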
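A quick behavioral check on the refactor: the new one-liner temperature = max(1e-2, float(temperature)) reproduces the removed three-line guard exactly, clamping only from below:

def clamp_temperature(temperature):
    # Helper name is ours for illustration; the body matches the committed one-liner.
    return max(1e-2, float(temperature))

assert clamp_temperature(0.0) == 1e-2   # values below the floor are raised to it
assert clamp_temperature(0.9) == 0.9    # in-range values pass through unchanged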
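Worth knowing for the wiring above: gr.ChatInterface passes the live values of additional_inputs as extra positional arguments after (message, history), which is how the four sliders line up with the last four parameters of generate_response. A stripped-down sketch of that contract, with a hypothetical echo function:

import gradio as gr

def echo(message, history, temperature):
    # The slider's current value arrives as the third positional argument.
    return f"{message} (temperature={temperature})"

demo = gr.ChatInterface(
    echo,
    additional_inputs=[gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05)],
)
demo.launch()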