File size: 3,410 Bytes
cd650c7
 
a26857e
adba430
 
438b552
fa4d0d9
 
e06c59d
34421df
a5056fa
 
adba430
923f75f
adba430
 
 
 
 
 
 
fa4d0d9
 
 
 
 
34421df
a5056fa
fa4d0d9
 
 
e798af8
fa4d0d9
b4637e1
d53066f
a5056fa
 
45b3e18
a5056fa
d53066f
 
 
 
 
 
 
 
 
 
 
 
b4637e1
d53066f
 
 
 
 
 
 
a5056fa
b4637e1
45b3e18
e798af8
a5056fa
b4637e1
a5056fa
 
 
d53066f
 
fa4d0d9
34421df
 
 
59ef8d0
cd650c7
1fd65af
 
 
 
 
 
 
acf104d
 
 
adba430
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from huggingface_hub import InferenceClient
import gradio as gr
import PyPDF2
import random
import pandas as pd
from io import BytesIO 
import csv
import os
import io 
import tempfile
import re

# Hugging Face Inference API client pointed at the Mixtral-8x7B instruct model;
# used by generate() below for streamed text generation.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

def extract_text_from_pdf(file):
    """Extract and concatenate the text of every page in a PDF.

    Args:
        file: A file path or binary file-like object accepted by
            ``PyPDF2.PdfReader``.

    Returns:
        str: Text of all pages concatenated in page order. Pages with no
        extractable text (e.g. scanned/image-only pages, where
        ``extract_text()`` can return ``None``) contribute nothing instead
        of raising a TypeError.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    text = ""
    for page in pdf_reader.pages:  # iterate pages directly; no index needed
        text += page.extract_text() or ""
    return text

def save_to_csv(sentence, output, filename="synthetic_data.csv"):
    """Append one (sentence, output) row to *filename* in CSV format.

    The file is opened in append mode, so successive calls accumulate
    rows; the file is created on first use.
    """
    row = [sentence, output]
    with open(filename, mode='a', newline='', encoding='utf-8') as handle:
        csv.writer(handle).writerow(row)



def generate(file, temperature, max_new_tokens, top_p, repetition_penalty):
    """Produce synthetic sentence pairs from a PDF via the LLM client.

    Splits the PDF's text into sentences on '.', shuffles them, streams a
    model continuation for each, and writes (original, generated) pairs to
    a temporary CSV file whose path is returned so Gradio can offer it as
    a download.

    Args:
        file: The uploaded PDF, as provided by ``gr.File``.
        temperature: Sampling temperature forwarded to the model.
        max_new_tokens: Generation length cap forwarded to the model.
        top_p: Nucleus-sampling threshold forwarded to the model.
        repetition_penalty: Repetition penalty forwarded to the model.

    Returns:
        str: Filesystem path of the temporary CSV with the results.
    """
    text = extract_text_from_pdf(file)
    sentences = text.split('.')
    random.shuffle(sentences)  # Randomize order so output isn't document-ordered

    # encoding='utf-8' is required: without it the platform default (e.g.
    # cp1252 on Windows) may be unable to encode characters the model emits,
    # raising UnicodeEncodeError mid-run.
    with tempfile.NamedTemporaryFile(mode='w', newline='', delete=False,
                                     suffix='.csv', encoding='utf-8') as tmp:
        fieldnames = ['Original Sentence', 'Generated Sentence']
        writer = csv.DictWriter(tmp, fieldnames=fieldnames)
        writer.writeheader()

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            generate_kwargs = {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
                "top_p": top_p,
                "repetition_penalty": repetition_penalty,
                "do_sample": True,
                "seed": 42,  # fixed seed keeps runs reproducible despite sampling
                "use_cache": False
            }

            try:
                stream = client.text_generation(sentence, **generate_kwargs,
                                                stream=True, details=True,
                                                return_full_text=False)
                output = "".join(response.token.text for response in stream)

                # Split the model's reply into individual sentences.
                generated_sentences = re.split(r'(?<=[\.\!\?:])[\s\n]+', output)
                generated_sentences = [s.strip() for s in generated_sentences
                                       if s.strip() and s != '.']

                # Write each generated sentence as its own CSV row.
                for generated_sentence in generated_sentences:
                    writer.writerow({'Original Sentence': sentence,
                                     'Generated Sentence': generated_sentence})

            except Exception as e:
                # Best-effort: report the failure and keep processing the
                # remaining sentences rather than aborting the whole run.
                print(f"Error generating data for sentence '{sentence}': {e}")

        tmp_path = tmp.name

    return tmp_path
# Build the Gradio UI: a PDF upload plus sampling-parameter sliders feeding
# `generate`, which returns the path of a CSV file offered for download.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.File(label="Upload PDF File", file_count="single", file_types=[".pdf"]),
        gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
        gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64, interactive=True, info="The maximum numbers of new tokens"),
        gr.Slider(label="Top-p (nucleus sampling)", value=0.95, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
        gr.Slider(label="Repetition penalty", value=1.0, minimum=1.0, maximum=2.0, step=0.1, interactive=True, info="Penalize repeated tokens"),
    ],
    outputs=gr.File(label="Synthetic Data "),
    title="SDG",
    description="AYE QABIL.",
    allow_flagging="never",
)
demo.launch()