gpt / app.py
wannaphong's picture
Update app.py
afe040b verified
raw
history blame
1.49 kB
import torch
import transformers
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from threading import Thread
from transformers import TextIteratorStreamer
# Hub identifier of the causal LM checkpoint served by this app.
model_name = "numfa/numfa_v2-3b"

# Load the weights as float16; device_map="auto" leaves device placement to
# the library instead of pinning the model to a fixed device here.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Fall back to the EOS token id for padding when the tokenizer defines no
# pad token id of its own.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Module-level streamer: skips the prompt and special tokens in its output and
# waits at most 10 seconds for the next chunk of text.
streamer = TextIteratorStreamer(
    tokenizer,
    timeout=10.0,
    skip_prompt=True,
    skip_special_tokens=True,
)
def generate_text(prompt, max_length, top_p, top_k):
    """Stream sampled text from the model for a single prompt.

    Generation runs in a background thread while this generator yields the
    text accumulated so far, so the caller can display partial output.

    Args:
        prompt: Text to condition the generation on.
        max_length: Value passed to ``generate`` as ``max_length``; converted
            with ``int()``.
        top_p: Nucleus-sampling probability mass; converted with ``float()``.
        top_k: Number of highest-probability tokens kept for sampling;
            converted with ``int()``.

    Yields:
        str: All text generated so far (the streamer skips the prompt),
        growing with every chunk the streamer produces.

    Raises:
        ValueError: If ``max_length``, ``top_p`` or ``top_k`` cannot be
            converted to a number.
    """
    # The model is loaded with device_map="auto", so it may not live on the
    # CPU; move the encoded prompt to the model's device before generating.
    encoded = tokenizer([prompt], return_tensors="pt").to(model.device)

    # Use a fresh streamer per call. Sharing one streamer between calls lets
    # overlapping or aborted requests read each other's queued tokens.
    call_streamer = TextIteratorStreamer(
        tokenizer,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        encoded,
        max_length=int(max_length),
        top_p=float(top_p),
        do_sample=True,
        top_k=int(top_k),
        streamer=call_streamer,
    )

    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    pieces = []
    for text in call_streamer:
        pieces.append(text)
        yield "".join(pieces)

    # The streamer has signalled the end of generation, so the worker is
    # finishing; join it so no thread is left dangling after the request.
    worker.join()
# Markdown text passed to the Interface as its description.
description = """
# Deploy your first ML app using Gradio
"""

# Input components, in the same order as generate_text's parameters
# (prompt, max_length, top_p, top_k). Textbox defaults are given as strings
# because the values arrive as text and generate_text converts them itself.
inputs = [
    gr.Textbox(label="Prompt text"),
    gr.Textbox(label="max-lenth generation", value="100"),
    gr.Slider(0.0, 1.0, label="top-p value", value=0.95),
    gr.Textbox(label="top-k", value="50"),
]
outputs = [gr.Textbox(label="Generated Text")]

# allow_flagging takes a string mode ("never", "auto" or "manual"); "never"
# is the explicit spelling of the previous boolean False.
demo = gr.Interface(
    fn=generate_text,
    inputs=inputs,
    outputs=outputs,
    allow_flagging="never",
    description=description,
)
demo.launch()