ramalMr committed on
Commit
9cbb806
·
verified ·
1 Parent(s): b6d3f79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -16
app.py CHANGED
@@ -4,7 +4,7 @@ import re
4
  import random
5
  import csv
6
  import tempfile
7
- import gradio as gr
8
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
9
 
10
  def extract_sentences_from_excel(file):
@@ -16,16 +16,16 @@ def extract_sentences_from_excel(file):
16
  sentences.extend([s.strip() for s in new_sentences if s.strip()])
17
  return sentences
18
 
19
- def generate_text(file, temperature, max_new_tokens, top_p, repetition_penalty, num_sentences=10000):
20
  sentences = extract_sentences_from_excel(file)
21
  random.shuffle(sentences)
22
 
23
  with tempfile.NamedTemporaryFile(mode='w', newline='', delete=False, suffix='.csv') as tmp:
24
- fieldnames = ['Original Sentence', 'Generated Sentence']
25
  writer = csv.DictWriter(tmp, fieldnames=fieldnames)
26
  writer.writeheader()
27
 
28
- for sentence in sentences[:num_sentences]:
29
  sentence = sentence.strip()
30
  if not sentence:
31
  continue
@@ -40,16 +40,14 @@ def generate_text(file, temperature, max_new_tokens, top_p, repetition_penalty,
40
  }
41
 
42
  try:
43
- stream = client.text_generation(sentence, **generate_kwargs, stream=True, details=True, return_full_text=False)
44
- output = ""
45
- for response in stream:
46
- output += response.token.text
47
 
48
- generated_sentences = re.split(r'(?<=[\.\!\?:])[\s\n]+', output)
49
  generated_sentences = [s.strip() for s in generated_sentences if s.strip() and s != '.']
50
 
51
  for generated_sentence in generated_sentences:
52
- writer.writerow({'Original Sentence': sentence, 'Generated Sentence': generated_sentence})
53
 
54
  except Exception as e:
55
  print(f"Error generating data for sentence '{sentence}': {e}")
@@ -57,18 +55,18 @@ def generate_text(file, temperature, max_new_tokens, top_p, repetition_penalty,
57
  tmp_path = tmp.name
58
 
59
  return tmp_path
 
60
  gr.Interface(
61
- fn=generate_text,
62
  inputs=[
63
  gr.File(label="Upload Excel File", file_count="single", file_types=[".xlsx", ".xls"]),
64
  gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
65
  gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64, interactive=True, info="The maximum numbers of new tokens"),
66
  gr.Slider(label="Top-p (nucleus sampling)", value=0.95, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
67
  gr.Slider(label="Repetition penalty", value=1.0, minimum=1.0, maximum=2.0, step=0.1, interactive=True, info="Penalize repeated tokens"),
68
- gr.Slider(label="Number of sentences", value=10000, minimum=1, maximum=100000, step=1000, interactive=True, info="The number of sentences to generate"),
69
  ],
70
- outputs=gr.File(label="Generated CSV"),
71
- title="Text Generation from Excel",
72
- description="Generate text from sentences in an Excel file and save it to a CSV file.",
73
  allow_flagging="never",
74
- ).launch()
 
4
  import random
5
  import csv
6
  import tempfile
7
+ import gradio as gr
8
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
9
 
10
  def extract_sentences_from_excel(file):
 
16
  sentences.extend([s.strip() for s in new_sentences if s.strip()])
17
  return sentences
18
 
19
+ def generate_synthetic_data(file, temperature, max_new_tokens, top_p, repetition_penalty):
20
  sentences = extract_sentences_from_excel(file)
21
  random.shuffle(sentences)
22
 
23
  with tempfile.NamedTemporaryFile(mode='w', newline='', delete=False, suffix='.csv') as tmp:
24
+ fieldnames = ['Original Sentence', 'Synthetic Data']
25
  writer = csv.DictWriter(tmp, fieldnames=fieldnames)
26
  writer.writeheader()
27
 
28
+ for sentence in sentences:
29
  sentence = sentence.strip()
30
  if not sentence:
31
  continue
 
40
  }
41
 
42
  try:
43
+ output = client.generate(sentence, **generate_kwargs, return_full_text=True)
44
+ generated_data = output.text.strip()
 
 
45
 
46
+ generated_sentences = re.split(r'(?<=[\.\!\?:])[\s\n]+', generated_data)
47
  generated_sentences = [s.strip() for s in generated_sentences if s.strip() and s != '.']
48
 
49
  for generated_sentence in generated_sentences:
50
+ writer.writerow({'Original Sentence': sentence, 'Synthetic Data': generated_sentence})
51
 
52
  except Exception as e:
53
  print(f"Error generating data for sentence '{sentence}': {e}")
 
55
  tmp_path = tmp.name
56
 
57
  return tmp_path
58
+
59
  gr.Interface(
60
+ fn=generate_synthetic_data,
61
  inputs=[
62
  gr.File(label="Upload Excel File", file_count="single", file_types=[".xlsx", ".xls"]),
63
  gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
64
  gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64, interactive=True, info="The maximum numbers of new tokens"),
65
  gr.Slider(label="Top-p (nucleus sampling)", value=0.95, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
66
  gr.Slider(label="Repetition penalty", value=1.0, minimum=1.0, maximum=2.0, step=0.1, interactive=True, info="Penalize repeated tokens"),
 
67
  ],
68
+ outputs=gr.File(label="Synthetic Data CSV"),
69
+ title="Synthetic Data Generation",
70
+ description="Generate synthetic data from sentences in an Excel file and save it to a CSV file.",
71
  allow_flagging="never",
72
+ ).launch()