data_gen / app.py
ramalMr's picture
Update app.py
4e49a48 verified
raw
history blame
3.8 kB
import csv
import functools
import io
import os
import random
import re
import tempfile
from io import BytesIO

import gradio as gr
import pandas as pd
from huggingface_hub import InferenceClient
from transformers import MarianMTModel, MarianTokenizer
# Inference client bound to the Mixtral instruct model; generate() streams its
# completions through this object.
MIXTRAL_MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
client = InferenceClient(model=MIXTRAL_MODEL_ID)
def extract_text_from_excel(file, column='Unnamed: 1'):
    """Concatenate the cells of one spreadsheet column into a single string.

    Args:
        file: Anything ``pd.read_excel`` accepts (a path or file-like object),
            or an already-loaded ``pd.DataFrame``.
        column: Name of the column whose cells are joined. The default
            ``'Unnamed: 1'`` is presumably the second column of a sheet whose
            header cell is blank (pandas names such columns
            ``'Unnamed: <index>'``) -- confirm against the expected input files.

    Returns:
        The non-missing cells of *column*, converted to ``str`` and joined
        with single spaces ("" when there are none).

    Raises:
        KeyError: if *column* is not present in the sheet.
    """
    df = file if isinstance(file, pd.DataFrame) else pd.read_excel(file)
    # Drop missing cells first: astype(str) would otherwise turn them into the
    # literal text "nan", which then ends up inside the generated prompts.
    cells = df[column].dropna().astype(str)
    return ' '.join(cells)
def save_to_csv(sentence, output, filename="synthetic_data.csv"):
    """Append a single ``(sentence, output)`` row to the CSV file *filename*.

    The file is opened in append mode as UTF-8, so it is created if missing and
    existing rows are kept. No header row is written.
    """
    row = (sentence, output)
    with open(filename, 'a', encoding='utf-8', newline='') as handle:
        csv.writer(handle).writerow(row)
@functools.lru_cache(maxsize=4)
def _load_marian(model_name):
    """Load the Marian tokenizer and model for *model_name* once and reuse them.

    Cached so that repeated translations do not re-instantiate the model from
    the hub/cache on every call.
    """
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model


def translate_english_to_azerbaijani(text, model_name='Helsinki-NLP/opus-mt-en-az'):
    """Translate one string with a MarianMT checkpoint (English -> Azerbaijani by default).

    Args:
        text: The source sentence.
        model_name: Hugging Face model id of the MarianMT checkpoint to use.

    Returns:
        The decoded translation of *text*, with special tokens removed.
    """
    tokenizer, model = _load_marian(model_name)
    # ``prepare_translation_batch`` is deprecated/removed in recent transformers
    # releases; calling the tokenizer directly is the supported way to build the
    # batch. Truncation keeps over-long inputs within the model's length limit.
    batch = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
    translated = model.generate(**batch)
    return tokenizer.decode(translated[0], skip_special_tokens=True)
# Whitespace that follows ".", "!", "?" or ":" -- used to cut model output into sentences.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[\.\!\?:])\s+')


def _stream_completion(prompt, generate_kwargs):
    """Stream a completion for *prompt* from the module-level ``client`` and return it as one string.

    Tokens the endpoint flags as special are skipped so that markers such as an
    end-of-sequence token do not end up in the generated text.
    """
    pieces = []
    stream = client.text_generation(
        prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    for response in stream:
        if not response.token.special:
            pieces.append(response.token.text)
    return "".join(pieces)


def _split_generated(output):
    """Split model output into stripped, non-empty sentences, dropping bare "." fragments."""
    parts = (part.strip() for part in _SENTENCE_BOUNDARY.split(output))
    return [part for part in parts if part and part != '.']


def generate(file, temperature, max_new_tokens, top_p, repetition_penalty):
    """Build a CSV of synthetic sentence pairs from an uploaded Excel file.

    The text returned by ``extract_text_from_excel`` is split on ".", the
    sentences are shuffled, and each one is sent as a prompt to ``client``. The
    completion is split into sentences. Every (prompt, generated sentence) pair
    is translated with ``translate_english_to_azerbaijani`` and written as one
    CSV row. A prompt that fails is reported with ``print`` and skipped.

    Args:
        file: The uploaded Excel file, as passed by the Gradio ``File`` input.
        temperature: Sampling temperature (floored at 0.01).
        max_new_tokens: Maximum tokens generated per prompt (at least 1).
        top_p: Nucleus-sampling mass (clamped into [0.01, 0.99]).
        repetition_penalty: Repetition penalty passed through unchanged.

    Returns:
        Path of a temporary ``.csv`` file (not deleted automatically) with the
        columns "Original Sentence" and "Generated Sentence".

    Raises:
        gr.Error: if no file was uploaded.
    """
    if file is None:
        raise gr.Error("Please upload an Excel file.")

    # Text Generation Inference rejects temperature <= 0, top_p outside (0, 1)
    # and max_new_tokens < 1 -- values the UI sliders can produce. Clamp them so
    # every request is valid instead of every prompt failing silently.
    generate_kwargs = {
        "temperature": max(float(temperature), 1e-2),
        "max_new_tokens": max(int(max_new_tokens), 1),
        "top_p": min(max(float(top_p), 0.01), 0.99),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }

    text = extract_text_from_excel(file)
    # NOTE: naive split -- abbreviations and decimals ("e.g.", "3.5") are cut as well.
    sentences = text.split('.')
    random.shuffle(sentences)

    # Explicit UTF-8: translated rows may contain non-ASCII characters that the
    # platform default encoding cannot represent.
    with tempfile.NamedTemporaryFile(
        mode='w', newline='', delete=False, suffix='.csv', encoding='utf-8'
    ) as tmp:
        writer = csv.DictWriter(tmp, fieldnames=['Original Sentence', 'Generated Sentence'])
        writer.writeheader()
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            try:
                generated_sentences = _split_generated(_stream_completion(sentence, generate_kwargs))
                if not generated_sentences:
                    continue
                # The prompt's translation is identical for every row; compute it once.
                translated_original = translate_english_to_azerbaijani(sentence)
                for generated_sentence in generated_sentences:
                    writer.writerow({
                        'Original Sentence': translated_original,
                        'Generated Sentence': translate_english_to_azerbaijani(generated_sentence),
                    })
            except Exception as e:
                # Best effort: one failing prompt must not abort the whole file.
                print(f"Error generating data for sentence '{sentence}': {e}")
    return tmp.name
# Gradio front end: upload an .xlsx file, tune the sampling parameters and
# download the generated CSV.
# Slider ranges stay within what the text-generation backend accepts:
# temperature and top_p must be strictly positive, top_p must be below 1.0, and
# at least one new token must be requested.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.File(label="Upload Excel File", file_count="single", file_types=[".xlsx"]),
        gr.Slider(
            label="Temperature", value=0.9, minimum=0.05, maximum=1.0, step=0.05,
            interactive=True, info="Higher values produce more diverse outputs",
        ),
        gr.Slider(
            label="Max new tokens", value=256, minimum=64, maximum=5120, step=64,
            interactive=True, info="The maximum numbers of new tokens",
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)", value=0.95, minimum=0.05, maximum=0.95, step=0.05,
            interactive=True, info="Higher values sample more low-probability tokens",
        ),
        gr.Slider(
            label="Repetition penalty", value=1.0, minimum=1.0, maximum=2.0, step=0.1,
            interactive=True, info="Penalize repeated tokens",
        ),
    ],
    outputs=gr.File(label="Synthetic Data "),
    title="SDG",
    description="AYE QABIL.",
    allow_flagging="never",
)
demo.launch()