import gradio as gr
import torch
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
# Loading the tokenizer and model from Hugging Face's model hub.
if torch.cuda.is_available():
    tokenizer = AutoTokenizer.from_pretrained("upstage/SOLAR-10.7B-Instruct-v1.0")
    model = AutoModelForCausalLM.from_pretrained(
        "upstage/SOLAR-10.7B-Instruct-v1.0",
        torch_dtype=torch.float16,
        device_map="auto",
    )
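else:
    # Hedged fallback, not part of the original Space: loading on CPU so the
    # script still runs without CUDA (generation will be very slow). The
    # `predict` function below sends inputs to `model.device`, so this path
    # stays consistent with the GPU one.
    tokenizer = AutoTokenizer.from_pretrained("upstage/SOLAR-10.7B-Instruct-v1.0")
    model = AutoModelForCausalLM.from_pretrained("upstage/SOLAR-10.7B-Instruct-v1.0")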
# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Token id 2 should be </s> (end-of-sequence) in SOLAR's Llama-family
        # tokenizer; tokenizer.eos_token_id would be the more robust spelling.
        stop_ids = [2]  # IDs of tokens where the generation should stop.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
                return True
        return False
# Function to generate model predictions.
@spaces.GPU()
def predict(message, history):
    stop = StopOnTokens()
    # Rebuilding the conversation in the chat-template format the model expects.
    conversation = []
    for user, assistant in history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # Sending inputs to the model's device (GPU here; CPU if the fallback above was used).
    model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Streaming tokens as they are generated instead of waiting for the full reply.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.6,
        repetition_penalty=1.2,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        # Breaking the loop if the stop token is generated. With
        # skip_special_tokens=True this is only a safety net, since the literal
        # '</s>' should not normally appear in the decoded text.
        if '</s>' in partial_message:
            break
        yield partial_message
# Setting up the Gradio chat interface.
gr.ChatInterface(
    predict,
    title="SOLAR 10.7B Instruct v1.0",
    description="Warning: all answers are generated and may contain inaccurate information.",
    examples=['How do you cook fish?', 'Who is the president of the United States?'],
).launch()  # Launching the web interface.