ramalMr committed (verified)
Commit 664305c · 1 Parent(s): 35fbaa2

Update app.py

Files changed (1)
  1. app.py +17 -24
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
 import random
 import pandas as pd
 from io import BytesIO
-import csv
+import csv
 import os
 import io
 import tempfile
@@ -11,23 +11,23 @@ import re
 
 client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
-def extract_text_from_excel(file):
+def extract_sentences_from_excel(file):
     df = pd.read_excel(file)
-    text = ' '.join(df['Unnamed: 1'].astype(str))
-    return text
+    text = ' '.join(df['Column_Name'].astype(str))
+    sentences = text.split('.')
+    sentences = [s.strip() for s in sentences if s.strip()]
+    return sentences
 
-def save_to_csv(prompt, sentence, output, filename="synthetic_data.csv"):
+def save_to_csv(sentence, output, filename="synthetic_data.csv"):
     with open(filename, mode='a', newline='', encoding='utf-8') as file:
         writer = csv.writer(file)
-        writer.writerow([prompt, sentence, output])
+        writer.writerow([sentence, output])
 
-def generate(file, prompt, temperature, max_new_tokens, top_p, repetition_penalty, num_similar_sentences):
-    text = extract_text_from_excel(file)
-    sentences = text.split('.')
-    random.shuffle(sentences)  # Shuffle sentences
+def generate(file, prompt, temperature, max_new_tokens, top_p, repetition_penalty):
+    sentences = extract_sentences_from_excel(file)
 
     with tempfile.NamedTemporaryFile(mode='w', newline='', delete=False, suffix='.csv') as tmp:
-        fieldnames = ['Prompt', 'Original Sentence', 'Generated Sentence']
+        fieldnames = ['Original Sentence', 'Generated Sentence']
         writer = csv.DictWriter(tmp, fieldnames=fieldnames)
         writer.writeheader()
 
@@ -46,7 +46,7 @@ def generate(file, prompt, temperature, max_new_tokens, top_p, repetition_penalty, num_similar_sentences):
             }
 
             try:
-                stream = client.text_generation(prompt + sentence, **generate_kwargs, stream=True, details=True, return_full_text=False)
+                stream = client.text_generation(f"{prompt} {sentence}", **generate_kwargs, stream=True, details=True, return_full_text=False)
                 output = ""
                 for response in stream:
                     output += response.token.text
@@ -54,34 +54,27 @@ def generate(file, prompt, temperature, max_new_tokens, top_p, repetition_penalty, num_similar_sentences):
                 generated_sentences = re.split(r'(?<=[\.\!\?:])[\s\n]+', output)
                 generated_sentences = [s.strip() for s in generated_sentences if s.strip() and s != '.']
 
-                for _ in range(num_similar_sentences):
-                    if not generated_sentences:
-                        break
-                    generated_sentence = generated_sentences.pop(random.randrange(len(generated_sentences)))
-                    writer.writerow({'Prompt': prompt, 'Original Sentence': sentence, 'Generated Sentence': generated_sentence})
+                for generated_sentence in generated_sentences:
+                    writer.writerow({'Original Sentence': sentence, 'Generated Sentence': generated_sentence})
 
             except Exception as e:
                 print(f"Error generating data for sentence '{sentence}': {e}")
 
         tmp_path = tmp.name
 
-    return tmp_path, output
+    return tmp_path
 
 gr.Interface(
     fn=generate,
     inputs=[
         gr.File(label="Upload Excel File", file_count="single", file_types=[".xlsx"]),
-        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
+        gr.Textbox(label="Prompt", placeholder="Enter your prompt here"),
         gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
         gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64, interactive=True, info="The maximum numbers of new tokens"),
         gr.Slider(label="Top-p (nucleus sampling)", value=0.95, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
         gr.Slider(label="Repetition penalty", value=1.0, minimum=1.0, maximum=2.0, step=0.1, interactive=True, info="Penalize repeated tokens"),
-        gr.Slider(label="Number of similar sentences", value=10, minimum=1, maximum=20, step=1, interactive=True, info="Number of similar sentences to generate for each original sentence"),
-    ],
-    outputs=[
-        gr.File(label="Synthetic Data"),
-        gr.Textbox(label="Generated Output")
     ],
+    outputs=gr.File(label="Synthetic Data "),
    title="SDG",
    description="AYE QABIL.",
    allow_flagging="never",
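
For quick reference, below is a minimal, self-contained sketch (not part of the commit) of the two rewritten helpers as they look after this change, with the Mixtral text_generation call replaced by a stub so it runs offline. sample.xlsx is a hypothetical input file, 'Column_Name' is the same placeholder column name used in the updated app.py, and pandas is assumed to have openpyxl available for reading .xlsx workbooks.

# Sketch only: mirrors extract_sentences_from_excel / save_to_csv from this
# commit, with the model call stubbed out. Assumes an .xlsx file that has a
# 'Column_Name' column (placeholder name, as in the updated app.py).
import csv
import pandas as pd

def extract_sentences_from_excel(file):
    # Join one column into a single string, split on '.', drop empty pieces.
    df = pd.read_excel(file)  # requires openpyxl for .xlsx workbooks
    text = ' '.join(df['Column_Name'].astype(str))
    return [s.strip() for s in text.split('.') if s.strip()]

def save_to_csv(sentence, output, filename="synthetic_data.csv"):
    # New two-column signature: one (original, generated) row per call.
    with open(filename, mode='a', newline='', encoding='utf-8') as f:
        csv.writer(f).writerow([sentence, output])

if __name__ == "__main__":
    for sentence in extract_sentences_from_excel("sample.xlsx"):  # hypothetical file
        fake_output = sentence.upper()  # stand-in for client.text_generation(...)
        save_to_csv(sentence, fake_output)

With the per-sentence count slider removed, every sentence the model produces for a given input (after the regex split) is written to the temporary CSV, and the interface now returns only that file instead of a file plus a text box.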