# dolphin / app.py
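"""Gradio chat demo serving Dolphin 2.9.3 Mistral 7B (32k) with 4-bit NF4 quantization."""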
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
# Place tensors on the GPU by default.
torch.set_default_device("cuda")

# Load the model in 4-bit NF4 with double quantization, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_id = "cognitivecomputations/dolphin-2.9.3-mistral-7B-32k"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
def predict(input_text, history):
    # Rebuild the conversation as chat-template messages from the Gradio history pairs.
    chat = []
    for user_msg, assistant_msg in history:
        chat.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            chat.append({"role": "assistant", "content": assistant_msg})
    chat.append({"role": "user", "content": input_text})

    # Apply the model's chat template and append the generation prompt so the
    # model continues as the assistant.
    conv = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(conv, return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=512)

    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
gr.ChatInterface(predict, theme="soft").launch()