# dolphin/app.py — Hugging Face Space (header metadata from the Space file
# viewer: author nroggendorff, commit 1e38c28, "Update app.py", 1.29 kB)
import gradio as gr
import os
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Dolphin 2.8 (Mistral 7B) chat model, loaded once at module import.
model_path = "cognitivecomputations/dolphin-2.8-mistral-7b-v02"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_path)
# The model has no dedicated pad token; reuse EOS so generate() can pad batches.
model.config.pad_token_id = model.config.eos_token_id
# ChatML-style system turn. (Was an f-string with no placeholders — plain
# string literal is equivalent and clearer.)
system_prompt = "<|im_start|>system\nYou are Santa.<|im_end|>\n"
# Running conversation transcript, appended to by chat() below.
history = system_prompt
@spaces.GPU(duration=120)
def chat(prompt):
    """Append *prompt* as a user turn, generate an assistant reply, and
    return the reply text.

    Side effect: extends the module-level ``history`` transcript with both
    the user turn and the generated assistant turn.
    """
    # BUG FIX: the augmented assignment below makes `history` function-local,
    # so reading it first raised UnboundLocalError without this declaration.
    global history
    input_text = history + "<|im_start|>user\n" + prompt + "<|im_end|>\n"
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    attention_mask = torch.ones_like(input_ids)
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # BUG FIX: was max_length=1024, which caps prompt + completion
        # together — once the history grew near 1024 tokens there was no
        # room left to generate. max_new_tokens bounds only the reply.
        max_new_tokens=1024,
        num_return_sequences=1,
        top_p=0.9,
        top_k=50,
        num_beams=2,
        pad_token_id=model.config.eos_token_id,
    )
    # BUG FIX: decode only the newly generated tokens; decoding output[0]
    # wholesale echoed the entire system/user history back as the "reply".
    new_tokens = output[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # BUG FIX: persist the user turn too — previously only the assistant
    # turn was appended, so past user messages vanished from the context.
    history = input_text + "<|im_start|>assistant\n" + response + "<|im_end|>\n"
    return response
# Minimal single-turn UI: one textbox in, one textbox out, wired to chat().
_prompt_box = gr.Textbox(placeholder="Enter your message here")
_reply_box = gr.Textbox(label="Response")
demo = gr.Interface(fn=chat, inputs=_prompt_box, outputs=_reply_box)

if __name__ == "__main__":
    demo.launch()