Spaces:
Paused
Paused
File size: 2,032 Bytes
a896396 0d15563 f9588c9 8339753 0d15563 5c3c196 fe11a00 310f92d 0d15563 f086864 0d15563 8339753 0d15563 8339753 0d15563 fe11a00 280b665 fe11a00 0d15563 ee70b76 0d15563 eab0d43 fe11a00 ee70b76 0d15563 f9588c9 310f92d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation when a stop token is produced.

    Only the first sequence in the batch (``input_ids[0]``) is inspected.

    Args:
        stop_ids: Iterable of token ids that terminate generation. Defaults to
            ``(50256, 50295)``, the ids that were previously hard-coded here.
            NOTE(review): token ids are tokenizer-specific — confirm the
            defaults match the tokenizer of the model being served, or pass
            the right ids explicitly.
    """

    DEFAULT_STOP_IDS = (50256, 50295)

    def __init__(self, stop_ids=DEFAULT_STOP_IDS):
        super().__init__()
        # A frozenset gives O(1) membership tests and guards against mutation.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the most recent token of sequence 0 is a stop id."""
        # Convert once to a Python int so there is a single tensor-to-host
        # read instead of one tensor comparison per stop id.
        last_token_id = int(input_ids[0][-1])
        return last_token_id in self.stop_ids
_MODEL_ID = "cognitivecomputations/dolphin-2.8-mistral-7b-v02"
_SYSTEM_PROMPT = "<|im_start|>system\nYou are Dolphin, a helpful AI assistant.<|im_end|>"
_END_OF_TURN = "<|im_end|>"


def _build_chatml_prompt(message, history):
    """Render the system prompt, past turns and the new message as ChatML text."""
    parts = [_SYSTEM_PROMPT]
    for turn in history:
        user_text, assistant_text = turn[0], turn[1]
        parts.append("\n<|im_start|>user\n" + user_text + _END_OF_TURN)
        # Completed assistant replies must be closed with <|im_end|> as well;
        # otherwise the next user turn starts inside an unterminated block.
        parts.append("\n<|im_start|>assistant\n" + (assistant_text or "") + _END_OF_TURN)
    parts.append("\n<|im_start|>user\n" + message + _END_OF_TURN)
    # Leave the assistant turn open so the model writes the reply.
    parts.append("\n<|im_start|>assistant\n")
    return "".join(parts)


@spaces.GPU(duration=480)
def predict(message, history):
    """Stream a chat reply for ``message`` given the previous conversation.

    Args:
        message: The latest user message.
        history: Previous turns; each item is indexed as ``[0]`` (user text)
            and ``[1]`` (assistant text).

    Yields:
        The reply generated so far; each yielded string extends the previous
        one with the newly decoded text.
    """
    torch.set_default_device("cuda")
    # The tokenizer and model are loaded on every call, inside the
    # GPU-decorated function.
    # NOTE(review): trust_remote_code=True allows code shipped in the model
    # repository to run during loading — confirm it is actually required.
    tokenizer = AutoTokenizer.from_pretrained(
        _MODEL_ID,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID,
        torch_dtype="auto",
        load_in_4bit=True,
        trust_remote_code=True,
    )

    prompt = _build_chatml_prompt(message, history)
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # Resolve end-of-turn ids from the tokenizer actually in use.
    # NOTE(review): the ids hard-coded in StopOnTokens look like they belong
    # to a different tokenizer — confirm; these ids make stopping independent
    # of that.
    stop_token_ids = []
    for token_id in (tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(_END_OF_TURN)):
        if token_id is None or token_id == tokenizer.unk_token_id:
            continue
        if token_id not in stop_token_ids:
            stop_token_ids.append(token_id)

    # The tokenizer output is a mapping, so it seeds the kwargs with the
    # encoded prompt tensors before the generation settings are added.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    if stop_token_ids:
        generate_kwargs["eos_token_id"] = stop_token_ids

    # generate() blocks, so it runs in a worker thread while this generator
    # consumes the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        # Fallback for the case where the marker is decoded as plain text.
        if _END_OF_TURN in partial_message:
            break
        yield partial_message

    # Do not return (and end the GPU-decorated call) while generation is
    # still running in the background.
    worker.join()
# Build the chat UI around predict(); start the server only when this file is
# executed as a script, so importing the module has no side effects.
demo = gr.ChatInterface(predict)

if __name__ == "__main__":
    demo.launch()