ramalMr committed on
Commit
112c42e
·
verified ·
1 Parent(s): 585a4f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -73
app.py CHANGED
@@ -1,104 +1,100 @@
1
  from huggingface_hub import InferenceClient
2
  import gradio as gr
3
  import re
4
- from nltk.tokenize import sent_tokenize
5
 
6
  client = InferenceClient(
7
  "mistralai/Mixtral-8x7B-Instruct-v0.1"
8
  )
9
 
10
def tokenize_sentences(file_content):
    """Split raw file content into sentences with NLTK.

    Args:
        file_content: the file's text, either ``bytes`` or an already-decoded
            ``str`` (callers in this file pass both forms).

    Returns:
        list[str]: the sentences produced by ``nltk.tokenize.sent_tokenize``.
    """
    # Bug fix: `generate` decodes each upload before calling us, so the old
    # unconditional .decode() raised AttributeError on str input. Only decode
    # when we actually receive bytes.
    if isinstance(file_content, bytes):
        file_content = file_content.decode()
    return sent_tokenize(file_content)
13
-
14
def generate_synthetic_data(prompt, sentences, data_size, toxicity_level, use_emoji):
    """Render up to *data_size* sentences as prompt-prefixed synthetic lines.

    Each sentence becomes "{prompt}: {sentence}", is pushed through the
    toxicity adjuster matching *toxicity_level* ("High" / "Low", anything
    else is left as-is), then has emojis added or removed per *use_emoji*.
    Returns the transformed lines joined with newlines.
    """

    def _transform(sentence):
        # One sentence -> one synthetic line, per the selected options.
        line = f"{prompt}: {sentence}"
        if toxicity_level == "High":
            line = add_toxic_content(line)
        elif toxicity_level == "Low":
            line = remove_toxic_content(line)
        return add_emojis(line) if use_emoji else remove_emojis(line)

    return "\n".join(_transform(s) for s in sentences[:data_size])
34
 
35
def add_toxic_content(text):
    """Make *text* more toxic.

    Placeholder implementation: returns the input unchanged until real
    logic is written.
    """
    return text
 
 
 
 
38
 
39
def remove_toxic_content(text):
    """Strip toxic content from *text*.

    Placeholder implementation: returns the input unchanged until real
    logic is written.
    """
    return text
 
 
 
 
 
42
 
43
def add_emojis(text):
    """Decorate *text* with emojis.

    Placeholder implementation: returns the input unchanged until real
    logic is written.
    """
    return text
46
 
47
def remove_emojis(text):
    """Strip emojis from *text*.

    Placeholder implementation: returns the input unchanged until real
    logic is written.
    """
    return text
 
50
 
51
def generate(prompt, max_data_size=100, toxicity_level="Neutral", use_emoji=False, files=None):
    """Produce synthetic data from uploaded files.

    Args:
        prompt: instruction text prefixed to every sentence.
        max_data_size: maximum number of sentences to use.
        toxicity_level: "High", "Low", or "Neutral".
        use_emoji: whether to add emojis (otherwise they are removed).
        files: list of uploaded file contents, or None.

    Returns:
        str: the synthetic data, or an instruction message when no files
        were uploaded.
    """
    # Guard clause: nothing to process without an upload.
    if files is None:
        return "Please upload a file to generate synthetic data."
    # Bug fix: the old code called file.decode() here AND tokenize_sentences
    # decoded again, so the second .decode() raised AttributeError on str.
    # Let tokenize_sentences own the decoding instead.
    sentences = []
    for file in files:
        sentences.extend(tokenize_sentences(file))
    return generate_synthetic_data(prompt, sentences, max_data_size, toxicity_level, use_emoji)
61
 
62
# Gradio controls exposed alongside the main prompt box.
additional_inputs = [
    gr.Textbox(label="Prompt for Synthetic Data Generation", max_lines=1, interactive=True),
    gr.Slider(label="Max Data Size", value=100, minimum=10, maximum=1000, step=10,
              interactive=True,
              info="The maximum number of sentences to include in the synthetic data"),
    gr.Radio(label="Toxicity Level", choices=["High", "Low", "Neutral"], value="Neutral",
             interactive=True,
             info="Adjust the toxicity level of the synthetic data"),
    gr.Checkbox(label="Use Emoji", value=False, interactive=True,
                info="Add or remove emojis in the synthetic data"),
    gr.File(label="Upload PDF or Document", file_count="multiple",
            file_types=[".pdf", ".doc", ".docx", ".txt"], interactive=True),
]

# Wire the inputs to `generate` and serve the app.
gr.Interface(
    fn=generate,
    inputs=additional_inputs,
    outputs="text",
    title="Synthetic-data-generation-aze",
    description="Generate synthetic data from uploaded files based on a given prompt and customization options.",
).launch(show_api=False)
 
1
  from huggingface_hub import InferenceClient
2
  import gradio as gr
3
  import re
 
4
 
5
  client = InferenceClient(
6
  "mistralai/Mixtral-8x7B-Instruct-v0.1"
7
  )
8
 
9
def format_prompt(message, history):
    """Build a Mixtral-instruct prompt from the chat *history* plus *message*.

    Each (user, bot) turn is rendered as "[INST] user [/INST] bot</s> ", after
    a leading "<s>"; the new message is appended as a final open instruction.
    """
    pieces = ["<s>"]
    for user_turn, bot_turn in history:
        pieces.append(f"[INST] {user_turn} [/INST]")
        pieces.append(f" {bot_turn}</s> ")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
    file=None,
):
    """Stream a Mixtral completion for *prompt*, yielding the growing output.

    Args:
        prompt: the user's message from the chat box.
        history: list of (user, bot) turns, formatted via ``format_prompt``.
        system_prompt: text prepended to the prompt.
        temperature / max_new_tokens / top_p / repetition_penalty: sampling
            knobs forwarded to the inference endpoint.
        file: optional uploaded file content. Bug fix: ``additional_inputs``
            ends with a ``gr.File`` component, so ``gr.ChatInterface`` passes
            its value as an extra positional argument — without this
            parameter every submit raised TypeError. The default keeps the
            signature backward-compatible for direct callers.

    Yields:
        str: the accumulated generated text after each streamed token.
    """
    # The endpoint rejects temperatures at/near zero; clamp to a small floor.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: reproducible sampling across calls
    )

    # Previously-dead code path: fold uploaded-file sentences into the prompt.
    # NOTE(review): process_file assumes bytes input — confirm what gr.File
    # actually delivers in this gradio version.
    if file is not None:
        prompt = f"{prompt}\n\nContext:\n" + "\n".join(process_file(file))

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
        yield output
42
 
43
def process_file(file):
    """Decode uploaded file bytes and split them into non-empty sentences.

    *file* is expected to be the raw UTF-8 bytes of a text document;
    sentence boundaries are runs of '.', '!' or '?'. Surrounding whitespace
    is stripped and empty fragments are dropped.
    """
    raw_text = file.decode("utf-8")
    fragments = re.split(r"[.!?]+", raw_text)
    return [fragment.strip() for fragment in fragments if fragment.strip()]
 
 
 
 
 
48
 
49
# Sampling controls and the optional file upload shown under the chat box.
additional_inputs = [
    gr.Textbox(label="System Prompt", max_lines=1, interactive=True),
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05,
              interactive=True,
              info="Higher values produce more diverse outputs"),
    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64,
              interactive=True,
              info="The maximum numbers of new tokens"),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05,
              interactive=True,
              info="Higher values sample more low-probability tokens"),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05,
              interactive=True,
              info="Penalize repeated tokens"),
    gr.File(label="Upload File", file_count="single"),
]

# Chat UI wired to the streaming `generate` function.
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True,
                       likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Synthetic-data-generation-aze",
    concurrency_limit=20,
).launch(show_api=False)