File size: 3,625 Bytes
838b223
cd650c7
9cbb806
418de0b
 
838b223
d435c8a
97425d1
838b223
18cb91d
 
09b14bf
adba430
923f75f
664305c
8115786
08d8e2d
664305c
d435c8a
664305c
adba430
290168b
 
838b223
d0ee1ab
 
 
 
 
290168b
d0ee1ab
 
 
 
 
 
 
 
97425d1
664305c
 
838b223
fa4d0d9
838b223
18cb91d
 
 
 
a5056fa
18cb91d
 
 
 
 
 
 
 
d53066f
18cb91d
cf4b1fe
 
 
 
 
b54c869
7b026a2
18cb91d
 
fa4d0d9
838b223
18cb91d
34421df
664305c
9cbb806
838b223
418de0b
1fd65af
418de0b
664305c
1fd65af
 
 
 
b3b73c4
664305c
418de0b
 
adba430
09b14bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
from huggingface_hub import InferenceClient
import gradio as gr
import random
import pandas as pd
from io import BytesIO
import csv
import os
import io
import tempfile
import re

# Module-level inference client, created once at import time and bound to the
# "mistralai/Mixtral-8x7B-Instruct-v0.1" model id; generate() issues all of its
# text_generation requests through this single instance.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

def extract_sentences_from_excel(file, column='metn'):
    """Read an Excel sheet and split one text column into sentences.

    Every non-empty cell of ``column`` is converted to ``str``, the cells are
    joined with single spaces, and the combined text is split on ``'.'``.

    Args:
        file: Whatever ``pandas.read_excel`` accepts (the value is passed
            through unchanged).
        column: Name of the column holding the text. Defaults to ``'metn'``.

    Returns:
        list[str]: Whitespace-stripped, non-empty fragments in sheet order.

    Raises:
        KeyError: If the sheet has no column named ``column``.
    """
    df = pd.read_excel(file)
    if column not in df.columns:
        raise KeyError(f"column {column!r} not found in the Excel sheet")

    # Drop empty cells *before* joining. Otherwise each missing value is
    # stringified to "nan" and glued into the neighbouring sentence
    # ("... nan Next sentence"), which a per-fragment filter cannot remove.
    cells = df[column].dropna().astype(str)
    text = ' '.join(cells)

    fragments = (part.strip() for part in text.split('.'))
    # The literal 'nan' check is retained so a fragment that is exactly "nan"
    # is still discarded, as it was before.
    return [fragment for fragment in fragments if fragment and fragment != 'nan']

import re

def save_to_json(data, filename="synthetic_data.json"):
    """Parse raw model output for each sentence and write the result as JSON.

    Each element of ``data`` is a mapping with the keys ``'original_sentence'``
    and ``'generated_data'``. ``'generated_data'`` is scanned for every
    occurrence of the single-quoted pseudo-JSON object
    ``{'generated_sentence': '...', 'confidence_score': <number>}``.

    Args:
        data: Iterable of mappings as described above.
        filename: Destination path; it is overwritten. Written as UTF-8 with
            non-ASCII characters kept verbatim.

    Returns:
        None. The file receives a JSON list with one object per input item,
        holding ``original_sentence``, ``generated_sentences`` and
        ``confidence_scores`` (the two lists are index-aligned).

    Raises:
        KeyError: If an item lacks one of the two required keys.
    """
    item_pattern = re.compile(
        r"{'generated_sentence': '(.+?)', 'confidence_score': ([\d\.]+)}"
    )

    # Build the whole payload before touching the file, so a failure while
    # parsing cannot leave a truncated or empty output file behind.
    json_data = []
    for item in data:
        generated_sentences = []
        confidence_scores = []
        for match in item_pattern.finditer(item['generated_data']):
            try:
                score = float(match.group(2))
            except ValueError:
                # The character class also accepts strings such as "1.2.3" or
                # "..." that are not numbers; skip the pair to keep the two
                # lists aligned instead of aborting the whole export.
                continue
            generated_sentences.append(match.group(1))
            confidence_scores.append(score)
        json_data.append({
            'original_sentence': item['original_sentence'],
            'generated_sentences': generated_sentences,
            'confidence_scores': confidence_scores
        })

    with open(filename, mode='w', encoding='utf-8') as file:
        json.dump(json_data, file, indent=4, ensure_ascii=False)

def generate(file, prompt, temperature, max_new_tokens, top_p, repetition_penalty):
    """Generate synthetic text for every sentence of an uploaded Excel file.

    The sentences come from ``extract_sentences_from_excel(file)``. For each
    one, a request is sent through the module-level ``client`` with the user
    prompt, a fixed output-format instruction and the sentence itself. The
    streamed tokens are concatenated and the collected results are written
    with ``save_to_json`` to a fresh temporary ``.json`` file.

    Args:
        file: Uploaded Excel file, forwarded to ``extract_sentences_from_excel``.
        prompt: Instruction text placed in front of every request.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Forwarded unchanged to ``client.text_generation``.
        top_p: Forwarded unchanged to ``client.text_generation``.
        repetition_penalty: Forwarded unchanged to ``client.text_generation``.

    Returns:
        str: Path of the temporary JSON file. It is created with
        ``delete=False``, so it stays on disk after this function returns.

    A sentence whose request fails is reported on stdout and skipped; the
    remaining sentences are still processed (best effort).
    """
    sentences = extract_sentences_from_excel(file)
    data = []

    # NOTE(review): the UI slider allows 0.0, but sampling backends are
    # expected to reject a non-positive temperature when do_sample is on,
    # which would make every request fail. Clamp to a small positive value.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2

    # Identical for every sentence, so build it once outside the loop.
    generate_kwargs = {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }

    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue

        try:
            stream = client.text_generation(f"{prompt} Output the response in the following JSON format: {{'generated_sentence': 'The generated sentence text', 'confidence_score': 0.9}} {sentence}", **generate_kwargs, stream=True, details=True, return_full_text=False)
            output = "".join(response.token.text for response in stream)
            data.append({"original_sentence": sentence, "generated_data": output})
        except Exception as e:
            # Deliberate best effort: one failing request must not discard
            # the results already collected for other sentences.
            print(f"Error generating data for sentence '{sentence}': {e}")

    # Only reserve a unique path here and close the handle immediately;
    # save_to_json opens the path itself, so there is no reason to keep a
    # second, unused handle open while the network calls run.
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as tmp:
        tmp_path = tmp.name
    save_to_json(data, tmp_path)

    return tmp_path

# Input widgets, listed in the same order as generate()'s positional
# parameters: file, prompt, temperature, max_new_tokens, top_p,
# repetition_penalty.
_input_widgets = [
    gr.File(
        label="Upload Excel File",
        file_count="single",
        file_types=[".xlsx"],
    ),
    gr.Textbox(
        label="Prompt",
        placeholder="Enter your prompt here",
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=5120,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.95,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.0,
        minimum=1.0,
        maximum=2.0,
        step=0.1,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

# generate() returns a file path, so the single output is a file widget.
_result_widget = gr.File(label="Synthetic Data ")

_app = gr.Interface(
    fn=generate,
    inputs=_input_widgets,
    outputs=_result_widget,
    title="SDG",
    description="AYE QABIL.",
    allow_flagging="never",
)

# Start the web UI as soon as the module is executed.
_app.launch()