File size: 3,557 Bytes
cd650c7
 
adba430
f5352d5
8115786
fa4d0d9
 
8115786
34421df
a5056fa
adba430
923f75f
f5352d5
8115786
f5352d5
 
 
 
 
 
adba430
fa4d0d9
 
 
 
 
8115786
f5352d5
e798af8
fa4d0d9
d53066f
a5056fa
 
8115786
a5056fa
8115786
d53066f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5056fa
45b3e18
e798af8
a5056fa
 
 
 
d53066f
 
fa4d0d9
34421df
 
 
8115786
 
cd650c7
1fd65af
8115786
1fd65af
 
 
 
8115786
1fd65af
8115786
f5352d5
acf104d
adba430
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from huggingface_hub import InferenceClient
import gradio as gr
import pandas as pd
import re
import random
import csv
import os
import io
import tempfile

# Shared Hugging Face inference client for the hosted Mixtral-8x7B-Instruct
# model; generate() sends every prompt through it.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

def extract_sentences_from_excel(file):
    """Read an Excel sheet and split its text into sentences.

    The non-empty cells of each row are joined with single spaces. The
    resulting text is split after sentence-ending punctuation (``.``, ``!``,
    ``?``) that is followed by whitespace, so the punctuation stays attached
    to the sentence it terminates. Decimal numbers such as ``1.5`` are kept
    intact.

    Args:
        file: Anything ``pandas.read_excel`` accepts (a path or file-like
            object), or an already-loaded ``pandas.DataFrame``.

    Returns:
        list[str]: Stripped, non-empty sentences in row order.
    """
    if isinstance(file, pd.DataFrame):
        df = file
    else:
        df = pd.read_excel(file)

    # Boundary = whitespace preceded by terminal punctuation.
    boundary = re.compile(r'(?<=[.!?])\s+')

    sentences = []
    for row in df.values.tolist():
        # Skip empty cells: str(NaN) would otherwise inject a literal "nan".
        text = ' '.join(str(cell) for cell in row if not pd.isna(cell))
        for piece in boundary.split(text):
            piece = piece.strip()
            if piece:
                sentences.append(piece)
    return sentences

def save_to_csv(sentence, output, filename="synthetic_data.csv"):
    """Append one ``(sentence, output)`` record to a CSV file.

    The file is opened in append mode (created if it does not exist) with
    UTF-8 encoding. No header row is written.

    Args:
        sentence: Value for the first column.
        output: Value for the second column.
        filename: Path of the CSV file to append to.
    """
    record = (sentence, output)
    with open(filename, 'a', encoding='utf-8', newline='') as handle:
        csv.writer(handle).writerow(record)

def generate(file, temperature, max_new_tokens, top_p, repetition_penalty, num_sentences=10000):
    """Generate synthetic text for sentences taken from an uploaded Excel file.

    Sentences are extracted with ``extract_sentences_from_excel`` and shuffled.
    Each of the first ``num_sentences`` is then sent as a prompt to the
    module-level ``client``. The streamed response is split into sentences,
    and each one is written as a row of a temporary CSV file.

    Args:
        file: The uploaded Excel file, as passed by the Gradio ``File`` input.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens per prompt; at least 1
            is requested.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.
        num_sentences: Maximum number of source sentences to process.

    Returns:
        Path of a UTF-8 encoded CSV file with the columns
        ``Original Sentence`` and ``Generated Sentence``. The file is created
        with ``delete=False``, so it is not removed automatically.
    """
    sentences = extract_sentences_from_excel(file)
    random.shuffle(sentences)  # Randomize which sentences fall within the limit.

    # Request parameters are identical for every sentence, so build them once.
    # Gradio sliders may hand back floats; the token budget and the slice need
    # ints. A zero temperature (allowed by the slider) is rejected by the
    # backend when sampling, so it is clamped upward.
    generate_kwargs = {
        "temperature": max(float(temperature), 1e-2),
        "max_new_tokens": max(int(max_new_tokens), 1),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }
    limit = int(num_sentences)
    # Split model output after terminal punctuation (or a colon) plus whitespace.
    splitter = re.compile(r'(?<=[.!?:])\s+')

    # Explicit UTF-8: the platform default encoding may not represent the text.
    with tempfile.NamedTemporaryFile(
        mode='w', newline='', encoding='utf-8', delete=False, suffix='.csv'
    ) as tmp:
        writer = csv.DictWriter(tmp, fieldnames=['Original Sentence', 'Generated Sentence'])
        writer.writeheader()

        for sentence in sentences[:limit]:
            sentence = sentence.strip()
            if not sentence:
                continue

            try:
                stream = client.text_generation(
                    sentence, **generate_kwargs, stream=True, details=True, return_full_text=False
                )
                # Drop special tokens (e.g. end-of-sequence) so they do not
                # land in the CSV as literal text.
                output = "".join(
                    response.token.text
                    for response in stream
                    if not getattr(response.token, "special", False)
                )
            except Exception as e:
                # Best effort: a failed request skips only this sentence.
                print(f"Error generating data for sentence '{sentence}': {e}")
                continue

            for piece in splitter.split(output):
                generated_sentence = piece.strip()
                if generated_sentence and generated_sentence != '.':
                    writer.writerow({'Original Sentence': sentence, 'Generated Sentence': generated_sentence})

    return tmp.name

# Gradio UI: upload an Excel file, tune the sampling parameters, and receive
# the generated CSV from generate().
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.File(label="Upload Excel File", file_count="single", file_types=[".xlsx", ".xls"]),
        gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
        gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64, interactive=True, info="The maximum numbers of new tokens"),
        gr.Slider(label="Top-p (nucleus sampling)", value=0.95, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
        gr.Slider(label="Repetition penalty", value=1.0, minimum=1.0, maximum=2.0, step=0.1, interactive=True, info="Penalize repeated tokens"),
        gr.Slider(label="Number of sentences", value=10000, minimum=1, maximum=100000, step=1000, interactive=True, info="The number of sentences to generate"),
    ],
    outputs=gr.File(label="Synthetic Data"),
    title="SDG",
    description="AYE QABIL.",
    allow_flagging="never",
)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module does not start a blocking web server.
if __name__ == "__main__":
    demo.launch()