# app.py — Gradio Space serving bigcode/starcoder2-3b on CPU.
# Provenance: Hugging Face Space by JoPmt, commit 293b850 ("Update app.py", 822 bytes).
# (Header converted from file-viewer residue so this file is valid Python.)
from accelerate import Accelerator
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import os, random, gc
import torch
# Force CPU execution: cpu=True tells Accelerate to ignore any GPU placement.
accelerator=Accelerator(cpu=True)
# Hub id of the model served by this Space.
mdl = "bigcode/starcoder2-3b"
# Downloaded/loaded at import time — a deliberate module-level side effect so
# the first request does not pay the load cost.
tokenizer = AutoTokenizer.from_pretrained(mdl)
# bfloat16 halves the weight memory footprint versus float32.
# NOTE(review): bfloat16 matmuls on CPU can be slow or fall back per-op — confirm acceptable latency.
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained(mdl, torch_dtype=torch.bfloat16))
def plex(ynputs):
    """Generate a continuation of the prompt with StarCoder2-3B on CPU.

    Parameters
    ----------
    ynputs : str
        Prompt text from the Gradio input textbox.

    Returns
    -------
    str
        Decoded text (prompt plus up to 350 newly generated tokens),
        with special tokens removed.
    """
    # The original wrapped the prompt as ""+ynputs+"" — a no-op concatenation.
    input_ids = tokenizer.encode(ynputs, return_tensors="pt").to("cpu")
    # Pure generation: inference_mode avoids building autograd state.
    with torch.inference_mode():
        generated = model.generate(input_ids, max_new_tokens=350)
    # skip_special_tokens keeps markers such as <|endoftext|> out of the reply
    # shown to the user.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Minimal ask/answer UI: one prompt box, one reply box, one button.
with gr.Blocks() as iface:
    # Reply box is created first so it renders above the prompt box,
    # matching the original layout.
    answer_box = gr.Textbox(lines=5)
    prompt_box = gr.Textbox(lines=2)
    ask_button = gr.Button("ASK")
    ask_button.click(fn=plex, inputs=prompt_box, outputs=answer_box)

# Serialize requests (max_size=1) and hide the programmatic API surface.
iface.queue(max_size=1, api_open=False)
iface.launch(max_threads=20, inline=False, show_api=False)