ramalMr committed
Commit 923f75f · verified · 1 Parent(s): cb69c24

Update app.py

Files changed (1)
  1. app.py +70 -48
app.py CHANGED
@@ -1,21 +1,22 @@
 from huggingface_hub import InferenceClient
 import gradio as gr
-import pandas as pd
 
-# Inference client initialization
-client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+client = InferenceClient(
+    "mistralai/Mixtral-8x7B-Instruct-v0.1"
+)
+
 
-# Function to format the prompt
 def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
 
-# Function to generate text based on prompt and history
-def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
+def generate(
+    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, files=None
+):
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
@@ -30,10 +31,12 @@ def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256
         seed=42,
     )
 
-    # Format the prompt
     formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
 
-    # Generate text using InferenceClient
+    if files is not None:
+        for file in files:
+            formatted_prompt += f"\n\nFile content: {file.decode()}"
+
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
 
@@ -42,44 +45,63 @@ def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256
         yield output
     return output
 
-# Additional input components for Gradio interface
+
 additional_inputs=[
-    gr.File(label="Upload CSV or Document", type="binary"), # Max file size is 2 GB
-    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
-    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64, interactive=True, info="The maximum numbers of new tokens"),
-    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
-    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
+    gr.Textbox(
+        label="System Prompt",
+        max_lines=1,
+        interactive=True,
+    ),
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=5120,
+        step=64,
+        interactive=True,
+        info="The maximum numbers of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.90,
+        minimum=0.0,
+        maximum=1,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    ),
+    gr.File(
+        label="Upload PDF or Document",
+        file_count="multiple",
+        file_types=[".pdf", ".doc", ".docx", ".txt"],
+        interactive=True,
+    )
 ]
 
-# Function to read uploaded CSV or Document
-def read_file(file):
-    if file is None:
-        return None
-    elif file.name.endswith('.csv'):
-        return pd.read_csv(file)
-    elif file.name.endswith('.txt'):
-        with open(file.name, 'r') as f:
-            return f.read()
 
-# Gradio Chat Interface
+
 gr.ChatInterface(
     fn=generate,
-    inputs=[
-        gr.inputs.Textbox(label="Prompt"),
-        gr.inputs.Textbox(label="History", placeholder="User1: Hello\nBot: Hi there!\nUser1: How are you?"),
-        gr.inputs.Textbox(label="System Prompt"),
-        gr.inputs.File(label="Upload CSV or Document", type="binary"), # Max file size is 2 GB
-    ],
-    outputs=gr.outputs.Textbox(label="Response"),
-    title="Synthetic-data-generation-aze",
+    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
     additional_inputs=additional_inputs,
-    examples=[
-        ["What is the capital of France?", "Paris", "Ask me anything"],
-        ["How are you?", "I'm good, thank you!", "User"],
-    ],
-    allow_flagging=False,
-    allow_upvoting=False,
-    allow_duplicate_of_same_input=False,
-    flagging_options=["Inappropriate", "Incorrect", "Offensive"],
-    thumbs=None,
-).launch()
+    title="Synthetic-data-generation-aze",
+    concurrency_limit=20,
+).launch(show_api=False)
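
A note on the prompt format: format_prompt wraps the conversation in the <s>[INST] ... [/INST] tags that Mixtral-8x7B-Instruct expects. A minimal sketch of what it returns for one prior exchange (the history and message strings below are made-up examples):

# Standalone copy of format_prompt from app.py, for illustration only.
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("What is 2 + 2?", "4")]  # hypothetical prior turn
print(format_prompt("And 3 + 3?", history))
# -> <s>[INST] What is 2 + 2? [/INST] 4</s> [INST] And 3 + 3? [/INST]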
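
On the wiring: gr.ChatInterface passes the current values of additional_inputs to fn as extra positional arguments after the message and history, so the component order above (system prompt, temperature, max new tokens, top-p, repetition penalty, files) lines up with the parameter order of generate after its first two parameters.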
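One caveat in the new upload handling: file.decode() assumes each element of files is raw bytes, but a gr.File created without type="binary", as above, hands the callback file paths in recent Gradio releases. A defensive reader (a sketch, not part of this commit; read_upload is a hypothetical name) would cover both cases:

from pathlib import Path

def read_upload(file):
    # Hypothetical helper: normalize an upload to text whether Gradio
    # delivers raw bytes (type="binary") or a filepath (the default).
    if isinstance(file, bytes):
        return file.decode(errors="replace")
    return Path(file).read_text(errors="replace")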