ramalMr committed on
Commit 43561b8 · verified · 1 Parent(s): cd650c7

Update app.py

Files changed (1)
  1. app.py +48 -61
app.py CHANGED
@@ -1,22 +1,21 @@
 from huggingface_hub import InferenceClient
 import gradio as gr
+import pandas as pd
 
-client = InferenceClient(
-    "mistralai/Mixtral-8x7B-Instruct-v0.1"
-)
-
+# Inference client initialization
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
+# Function to format the prompt
 def format_prompt(message, history):
     prompt = "<s>"
     for user_prompt, bot_response in history:
         prompt += f"[INST] {user_prompt} [/INST]"
         prompt += f" {bot_response}</s> "
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-def generate(
-    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
-):
+# Function to generate text based on prompt and history
+def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
@@ -31,7 +30,10 @@ def generate(
         seed=42,
     )
 
+    # Format the prompt
     formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+
+    # Generate text using InferenceClient
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
 
@@ -40,58 +42,43 @@ def generate(
         yield output
     return output
 
-
+# Additional input components for Gradio interface
 additional_inputs=[
-    gr.Textbox(
-        label="System Prompt",
-        max_lines=1,
-        interactive=True,
-    ),
-    gr.Slider(
-        label="Temperature",
-        value=0.9,
-        minimum=0.0,
-        maximum=1.0,
-        step=0.05,
-        interactive=True,
-        info="Higher values produce more diverse outputs",
-    ),
-    gr.Slider(
-        label="Max new tokens",
-        value=256,
-        minimum=0,
-        maximum=5120,
-        step=64,
-        interactive=True,
-        info="The maximum numbers of new tokens",
-    ),
-    gr.Slider(
-        label="Top-p (nucleus sampling)",
-        value=0.90,
-        minimum=0.0,
-        maximum=1,
-        step=0.05,
-        interactive=True,
-        info="Higher values sample more low-probability tokens",
-    ),
-    gr.Slider(
-        label="Repetition penalty",
-        value=1.2,
-        minimum=1.0,
-        maximum=2.0,
-        step=0.05,
-        interactive=True,
-        info="Penalize repeated tokens",
-    )
+    gr.File(label="Upload CSV or Document", type="upload", accept=".csv,.txt", max_size=2147483648),  # Max file size is 2 GB
+    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
+    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64, interactive=True, info="The maximum numbers of new tokens"),
+    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
+    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
 ]
 
+# Function to read uploaded CSV or Document
+def read_file(file):
+    if file is None:
+        return None
+    elif file.name.endswith('.csv'):
+        return pd.read_csv(file)
+    elif file.name.endswith('.txt'):
+        with open(file.name, 'r') as f:
+            return f.read()
 
-
+# Gradio Chat Interface
 gr.ChatInterface(
     fn=generate,
-    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
-    additional_inputs=additional_inputs,
+    inputs=[
+        gr.Textbox(label="Prompt"),
+        gr.Textbox(label="History", placeholder="User1: Hello\nBot: Hi there!\nUser1: How are you?"),
+        gr.Textbox(label="System Prompt"),
+    ],
+    outputs=gr.Textbox(label="Response"),
     title="Synthetic-data-generation-aze",
-    concurrency_limit=20,
-).launch(show_api=False)
-
+    additional_inputs=additional_inputs,
+    examples=[
+        ["What is the capital of France?", "Paris", "Ask me anything"],
+        ["How are you?", "I'm good, thank you!", "User"],
+    ],
+    allow_flagging=False,
+    allow_upvoting=False,
+    allow_duplicate_of_same_input=False,
+    flagging_options=["Inappropriate", "Incorrect", "Offensive"],
+    thumbs=None,
+).launch()
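Note on the prompt format: `format_prompt` serializes the running chat into Mixtral's [INST] instruction format, and `generate` glues the system prompt to the user message with a comma rather than giving it its own turn. A minimal trace of `format_prompt` as defined above (the sample history is illustrative):

    history = [("Hi", "Hello!")]
    print(format_prompt("How are you?", history))
    # -> <s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]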
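The diff context elides the loop between `output = ""` and `yield output`. The usual idiom for `huggingface_hub`'s `text_generation(..., stream=True, details=True)` is sketched below; this is an assumption about the standard pattern, not a reproduction of the elided lines:

    # Each streamed response carries the next token; accumulate and
    # re-yield the running text so Gradio can render it incrementally.
    for response in stream:
        output += response.token.text
        yield output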
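As committed, `read_file` is defined but never wired into `generate` or the interface, so uploads are parsed nowhere, and it implicitly returns None for any other extension. A hedged sketch of a more defensive version, assuming Gradio hands the handler either a path string or a tempfile-like object with a `.name` attribute (this varies across Gradio versions):

    import pandas as pd

    def read_file(file):
        # Return a DataFrame for CSVs, a string for .txt files, else None.
        if file is None:
            return None
        path = file if isinstance(file, str) else file.name
        if path.endswith(".csv"):
            return pd.read_csv(path)
        if path.endswith(".txt"):
            with open(path, "r", encoding="utf-8") as f:
                return f.read()
        return None  # explicit fallthrough for unsupported extensions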
 
64
+ # Gradio Chat Interface
65
  gr.ChatInterface(
66
  fn=generate,
67
+ inputs=[
68
+ gr.Textbox(label="Prompt"),
69
+ gr.Textbox(label="History", placeholder="User1: Hello\nBot: Hi there!\nUser1: How are you?"),
70
+ gr.Textbox(label="System Prompt"),
71
+ ],
72
+ outputs=gr.Textbox(label="Response"),
73
  title="Synthetic-data-generation-aze",
74
+ additional_inputs=additional_inputs,
75
+ examples=[
76
+ ["What is the capital of France?", "Paris", "Ask me anything"],
77
+ ["How are you?", "I'm good, thank you!", "User"],
78
+ ],
79
+ allow_flagging=False,
80
+ allow_upvoting=False,
81
+ allow_duplicate_of_same_input=False,
82
+ flagging_options=["Inappropriate", "Incorrect", "Offensive"],
83
+ thumbs=None,
84
+ ).launch()
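Several keyword arguments in the new UI block do not exist in Gradio's API: `gr.File` takes `file_types=[...]` rather than `type="upload"`/`accept`/`max_size`, and `gr.ChatInterface` builds its own message box and chatbot, so it accepts neither `inputs=`/`outputs=` nor `allow_upvoting=`, `allow_duplicate_of_same_input=`, or `thumbs=`; as written the call raises a TypeError. `examples` on `gr.ChatInterface` are example user messages (optionally followed by values for the additional inputs), not input/output pairs, so the committed triples would not behave as transcripts. The positional mapping is also off: `additional_inputs` feed `generate`'s parameters after `(message, history)`, so the uploaded file would land in `system_prompt`. A minimal sketch of the same UI against the documented API, assuming Gradio 4.x (the `uploaded_file` parameter is hypothetical and would need to be added to `generate`'s signature and routed through `read_file`):

    # Hedged sketch, assuming Gradio 4.x; not the committed code.
    additional_inputs = [
        gr.Textbox(label="System Prompt", max_lines=1),  # keeps system_prompt aligned with generate()'s signature
        gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05,
                  info="Higher values produce more diverse outputs"),
        gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64,
                  info="The maximum number of new tokens"),
        gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1.0, step=0.05,
                  info="Higher values sample more low-probability tokens"),
        gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05,
                  info="Penalize repeated tokens"),
        gr.File(label="Upload CSV or Document", file_types=[".csv", ".txt"]),  # would map to a new trailing parameter
    ]

    gr.ChatInterface(
        fn=generate,  # generate would need a trailing uploaded_file=None parameter for the gr.File input
        additional_inputs=additional_inputs,
        title="Synthetic-data-generation-aze",
    ).launch()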