# Bunny / app.py
import torch
import transformers
import warnings
import time
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from PIL import Image
from threading import Thread
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings("ignore")
device = "cuda"  # set to "cpu" to run without a GPU
torch.set_default_device(device)
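# Load the Bunny model and tokenizer. trust_remote_code=True is needed because the
# repository ships its own multimodal modeling code; float16 halves memory use compared
# to float32 (use float32 when running on CPU).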
model_name = "BAAI/Bunny-v1_1-Llama-3-8B-V"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16, # float32 for cpu
device_map="auto",
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True)
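# Streaming handler for gr.ChatInterface(multimodal=True): `message` is a dict with
# "text" and "files" keys, and `history` holds the previous turns. The @spaces.GPU
# decorator requests a GPU for the duration of the call on ZeroGPU Spaces.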
@spaces.GPU
def bot_streaming(message, history):
    # Take the most recently uploaded file for this turn, if any.
    if message["files"]:
        # each entry is either a dict with a "path" key or a plain path string
        if isinstance(message["files"][-1], dict):
            image_file = message["files"][-1]["path"]
        else:
            image_file = message["files"][-1]
    else:
        # no image uploaded this turn: fall back to the last image from past turns,
        # where uploaded files are kept inside tuples
        image_file = None
        for hist in history:
            if isinstance(hist[0], tuple):
                image_file = hist[0][0]
prompt = message["text"]
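    # Bunny uses a LLaVA-style conversation template; when an image is present,
    # an <image> placeholder marks where the vision tokens will be spliced in.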
if image_file is None:
text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:"
input_ids = torch.tensor(tokenizer(text).input_ids, dtype=torch.long).unsqueeze(0).to(device)
else:
text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")]
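        # -200 is the image-token placeholder index (LLaVA convention); drop the BOS
        # token that the tokenizer prepends to the second text chunk before splicing.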
input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0).to(device)
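    # Preprocess the image with the model's bundled process_images helper
    # (provided by the repository's remote code).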
if image_file is not None:
image = Image.open(image_file)
image_tensor = model.process_images([image], model.config).to(dtype=model.dtype, device=device)
else:
image_tensor = None
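    # Run generate() in a background thread and stream decoded text back to the UI
    # as it is produced.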
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
thread = Thread(target=model.generate, kwargs=dict(
inputs=input_ids,
images=image_tensor,
do_sample=True,
temperature=0.2,
top_p=0.7,
max_new_tokens=512,
streamer=streamer,
use_cache=True,
repetition_penalty=1.08
))
thread.start()
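    # Yield the growing buffer so ChatInterface re-renders the partial answer on each chunk.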
    buffer = ""
    time.sleep(0.5)
    for new_text in streamer:
        # strip the end-of-text marker if it appears in a streamed chunk
        if "<|end_of_text|>" in new_text:
            new_text = new_text.split("<|end_of_text|>")[0]
        buffer += new_text
        time.sleep(0.06)
        yield buffer
title_markdown = ("""
# 🐰 Bunny: A family of lightweight multimodal models
[📖 [Technical report](https://arxiv.org/abs/2402.11530)] | [🏠 [Code](https://github.com/BAAI-DCAI/Bunny)] | [🤗 [Bunny-v1.1-Llama-3-8B-V](https://huggingface.co/BAAI/Bunny-v1_1-Llama-3-8B-V)] | [🤗 [Bunny-v1.1-4B](https://huggingface.co/BAAI/Bunny-v1_1-4B)] | [🤗 [Bunny-v1.0-3B](https://huggingface.co/BAAI/Bunny-v1_0-3B)]
""")
chatbot = gr.Chatbot(
    elem_id="chatbot",
    label="Bunny-v1.1-Llama-3-8B-V",
    avatar_images=["./assets/user.png", "./assets/icon.jpg"],
    height=550
)
chat_input = gr.MultimodalTextbox(
interactive=True,
file_types=["image"],
placeholder="Enter message or upload file...",
show_label=False
)
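# Assemble the demo: title, streaming ChatInterface, and two image+question examples.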
with gr.Blocks(fill_height=True) as demo:
gr.Markdown(title_markdown)
gr.ChatInterface(
fn=bot_streaming,
stop_btn="Stop Generation",
multimodal=True,
textbox=chat_input,
chatbot=chatbot
)
gr.Examples(examples=[{"text": "What is the astronaut holding in his hand?", "files": ["./assets/example_1.png"]},
{"text": "Why is the image funny?", "files": ["./assets/example_2.png"]}], inputs=chat_input)
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)