# phi-3-mini / app.py
import torch
import gradio as gr
import spaces
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

model_name = "microsoft/Phi-3-mini-128k-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map='cuda' already places the model on the GPU, so no extra .to() call is needed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="cuda",
    trust_remote_code=True,
)

class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last sampled token is one of the stop ids."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [29, 0]  # hard-coded end-of-turn token ids for this tokenizer
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
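
# Note: stop ids are tokenizer-specific, so the hard-coded values above are worth
# double-checking against the loaded tokenizer. A quick sanity check (not part of
# the original app):
#   tokenizer.convert_ids_to_tokens([29, 0])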

@spaces.GPU(duration=180)
def predict(message, history):
    # Append the new user message, with an empty assistant slot, to the history.
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Serialize the history into Phi-3's chat markup. Completed turns are closed
    # with <|end|>; the final (empty) assistant slot is left open for generation.
    messages = "".join(
        "<|user|>\n" + user_msg + "<|end|>\n<|assistant|>\n" + bot_msg
        + ("<|end|>\n" if bot_msg else "")
        for user_msg, bot_msg in history_transformer_format
    )

    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,  # a BatchEncoding mapping: supplies input_ids and attention_mask
        streamer=streamer,
        max_new_tokens=4096,
        do_sample=True,
        top_p=0.9,
        top_k=40,
        temperature=0.9,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Run generation in a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        if new_token != "<":  # skip stray fragments of special-token markup
            partial_message += new_token
            yield partial_message

demo = gr.ChatInterface(fn=predict, examples=["What is life?"], title="AI", fill_height=True)
demo.launch(show_api=False)
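
# Usage sketch (assumptions, not from the original file): run `python app.py` on a
# machine with a CUDA GPU and `torch`, `transformers`, `gradio`, and `spaces`
# installed. Outside a Hugging Face ZeroGPU Space the @spaces.GPU decorator has no
# effect, so the app should also run as a plain Gradio script.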