|
|
|
|
|
import spaces |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
import gradio as gr |
|
import torch |
|
from transformers.utils import logging |
|
from example_queries import small_query, long_query |
|
|
|
# Raise the transformers library's log level so model-download and loading
# progress is visible in the Space's console logs.
logging.set_verbosity_info()

logger = logging.get_logger("transformers")
|
|
|
# Base model: plain t5-small, used as the comparison baseline.
model_name = 't5-small'
tokenizer = AutoTokenizer.from_pretrained(model_name)
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Fine-tuned variant of the same architecture, trained for SQL generation.
ft_model_name = "daljeetsingh/sql_ft_t5small_kag"
ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_name, torch_dtype=torch.bfloat16)

# BUG FIX: the original called .to('cuda') unconditionally, which raises on
# CPU-only hosts (e.g. local runs or a Space booting without a GPU attached).
# Fall back to CPU when CUDA is unavailable.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
original_model.to(device)
ft_model.to(device)
|
|
|
@spaces.GPU
def translate_text(text):
    """Run the same prompt through the base and fine-tuned models.

    Parameters
    ----------
    text : str
        The prompt to tokenize and feed to both models.

    Returns
    -------
    list[str]
        ``[base_model_output, fine_tuned_model_output]``. On failure both
        entries carry the same error message so each of the two Gradio
        output boxes still receives a value.
    """
    inputs = tokenizer(text, return_tensors='pt')
    # Follow the model's actual device instead of hard-coding 'cuda' so the
    # function also works when the models were loaded on CPU.
    inputs = inputs.to(original_model.device)

    def _generate(model):
        # Shared greedy-decode path for both models (removes the duplicated
        # generate/decode code of the original implementation).
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            output_ids = model.generate(
                inputs["input_ids"],
                max_new_tokens=200,
            )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)

    try:
        return [_generate(original_model), _generate(ft_model)]
    except Exception as e:
        # BUG FIX: the click handler maps this return onto TWO outputs
        # (orig_output, ft_output); the original returned a single string,
        # which made Gradio error out on the failure path. Return the
        # message for both boxes instead.
        return [f"Error: {e}", f"Error: {e}"]
|
|
|
|
|
# Two-column UI: prompt + button on the left, the two model outputs on the right.
with gr.Blocks() as demo:

    with gr.Row():

        with gr.Column():

            # Input prompt, pre-filled with a small example query.
            prompt = gr.Textbox(

                value=small_query,

                lines=8,

                placeholder="Enter prompt...",

                label="Prompt"

            )

            submit_btn = gr.Button(value="Generate")

        with gr.Column():

            # Side-by-side comparison: base model output vs fine-tuned output.
            orig_output = gr.Textbox(label="OriginalModel", lines=2)

            ft_output = gr.Textbox(label="FTModel", lines=8)



    # translate_text returns a two-element list that fills the two output
    # boxes in order. api_name=False keeps this handler out of the public API.
    submit_btn.click(

        translate_text, inputs=[prompt], outputs=[orig_output, ft_output], api_name=False

    )

    # Clickable example prompts that populate the input textbox when selected.
    examples = gr.Examples(

        examples=[

            [small_query],

            [long_query],

        ],

        inputs=[prompt],

    )



# share=True exposes a public tunnel URL; debug=True blocks and streams logs.
demo.launch(show_api=False, share=True, debug=True)
|
|
|
|