# Spaces: Running on Zero (hosting-page header captured with the source; not part of the program)
import time
import warnings
from threading import Thread

import gradio as gr
import spaces
import torch
import transformers
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Keep the console quiet: only errors from transformers, no progress bars,
# and no Python warnings.
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings("ignore")

# Every tensor created without an explicit device lands on this one.
device = "cuda"  # or cpu
torch.set_default_device(device)

# Checkpoint loaded once at import time and shared by every request.
model_name = "BAAI/Bunny-v1_1-Llama-3-8B-V"

# trust_remote_code=True executes Python shipped in the model repository;
# only keep it for repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # float32 for cpu
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
# Fixed preamble placed in front of every single-turn prompt.
_SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)
# Textual marker for the image position inside the prompt.
_IMAGE_PLACEHOLDER = "<image>"
# Id spliced into input_ids where the placeholder stood.
# NOTE(review): presumably the value the model's remote code recognises as
# the image slot - confirm against the model repository.
_IMAGE_TOKEN_INDEX = -200
# Marker after which streamed text is discarded.
_END_OF_TEXT = "<|end_of_text|>"


def _latest_image_path(message, history):
    """Return the image path for this turn, or None when no image is available."""
    files = message["files"]
    if files:
        newest = files[-1]
        # The last uploaded file is either a dict with a "path" key or a plain string.
        return newest["path"] if isinstance(newest, dict) else newest

    # No upload in this turn: fall back to the most recent past turn whose
    # user entry is a tuple (that is how earlier image uploads are kept).
    for past_turn in reversed(history):
        if isinstance(past_turn[0], tuple):
            return past_turn[0][0]
    return None


def _encode_prompt(prompt, with_image):
    """Tokenize the single-turn prompt into a (1, seq_len) long tensor on `device`."""
    if not with_image:
        text = f"{_SYSTEM_PROMPT} USER: {prompt} ASSISTANT:"
        token_ids = tokenizer(text).input_ids
    else:
        text = f"{_SYSTEM_PROMPT} USER: {_IMAGE_PLACEHOLDER}\n{prompt} ASSISTANT:"
        # Split only at the first placeholder so that a literal "<image>"
        # typed by the user stays in the prompt instead of truncating it.
        before, after = (
            tokenizer(chunk).input_ids
            for chunk in text.split(_IMAGE_PLACEHOLDER, 1)
        )
        # The first id of the second chunk is dropped.
        # NOTE(review): presumably the BOS token the tokenizer prepends on
        # every call - confirm.
        token_ids = before + [_IMAGE_TOKEN_INDEX] + after[1:]
    return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)


def _load_image_tensor(image_file):
    """Open `image_file` and run the model's image preprocessing; None passes through."""
    if image_file is None:
        return None
    # convert() reads the pixel data, so the file handle can be closed right
    # away; RGB keeps palette/alpha images from reaching the processor in an
    # unexpected mode.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert("RGB")
    return model.process_images([image], model.config).to(dtype=model.dtype, device=device)


# NOTE(review): `spaces` is imported at the top of the file but never used.
# On ZeroGPU hardware the GPU-using entry point is normally decorated with
# `@spaces.GPU` - confirm whether the decorator is missing here.
def bot_streaming(message, history):
    """Stream the model's reply to one multimodal chat message.

    Args:
        message: Mapping with a "text" entry (the user prompt) and a "files"
            entry (uploads for this turn; each a path string or a dict with
            a "path" key).
        history: Past turns. When this turn has no upload, the most recent
            turn whose first element is a tuple supplies the image path.

    Yields:
        The reply accumulated so far, once per chunk received from the
        streamer.
    """
    print(message)

    image_file = _latest_image_path(message, history)
    input_ids = _encode_prompt(message["text"], with_image=image_file is not None)
    image_tensor = _load_image_tensor(image_file)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
    )
    generation_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        top_p=0.7,
        max_new_tokens=512,
        streamer=streamer,
        use_cache=True,
        repetition_penalty=1.08,
    )
    # generate() blocks until completion, so it runs on a worker thread
    # while this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    time.sleep(0.5)
    for new_text in streamer:
        # Keep only what precedes the end-of-text marker, if it shows up.
        buffer += new_text.split(_END_OF_TEXT)[0]
        time.sleep(0.06)
        yield buffer
# Markdown header rendered at the top of the demo page.
# NOTE(review): the emoji were mis-decoded in the captured source. The rabbit
# and hugging-face glyphs are restored from the surviving bytes; the glyphs in
# front of "Technical report" and "Code" had lost their trailing bytes, so the
# book and house used here are best guesses - confirm against upstream.
title_markdown = ("""
# 🐰 Bunny: A family of lightweight multimodal models
[📖 [Technical report](https://arxiv.org/abs/2402.11530)] | [🏠 [Code](https://github.com/BAAI-DCAI/Bunny)] | [🤗 [Bunny-v1.1-Llama-3-8B-V](https://huggingface.co/BAAI/Bunny-v1_1-Llama-3-8B-V)] | [🤗 [Bunny-v1.1-4B](https://huggingface.co/BAAI/Bunny-v1_1-4B)] | [🤗 [Bunny-v1.0-3B](https://huggingface.co/BAAI/Bunny-v1_0-3B)]
""")
# Conversation pane; avatar_images is [user avatar, assistant avatar].
chatbot = gr.Chatbot(
    elem_id="chatbot",
    label="Bunny-v1.1-Llama-3-8B-V",
    avatar_images=["./assets/user.png", "./assets/icon.jpg"],
    height=550,
)

# Input box that takes text plus uploads; uploads are limited to images.
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)
# Page layout: header, chat interface wired to bot_streaming, clickable examples.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(title_markdown)
    gr.ChatInterface(
        fn=bot_streaming,
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )
    # Each example fills the multimodal textbox with a question and an image.
    gr.Examples(
        examples=[
            {"text": "What is the astronaut holding in his hand?", "files": ["./assets/example_1.png"]},
            {"text": "Why is the image funny?", "files": ["./assets/example_2.png"]},
        ],
        inputs=chat_input,
    )

# Queue requests and launch without the public API surface or a share link.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)