|
import gradio as gr |
|
import spaces |
|
import time |
|
import torch |
|
from PIL import Image |
|
from transformers import AutoProcessor, AutoModelForVision2Seq |
|
from transformers.image_utils import load_image |
|
from typing import List |
|
# Hugging Face Hub checkpoint served by this demo; the processor and the model
# are both loaded from this same repo at import time.
_MODEL_ID = "TIGER-Lab/Mantis-8B-Idefics2"

processor = AutoProcessor.from_pretrained(_MODEL_ID)
# bfloat16 weights; the model is moved to CUDA inside generate_stream.
model = AutoModelForVision2Seq.from_pretrained(_MODEL_ID, torch_dtype=torch.bfloat16)
|
|
|
@spaces.GPU
def generate_stream(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    """Stream a reply from the Mantis model for a chat conversation.

    Args:
        text: Not used; the prompt is built entirely from ``history``.
            Kept so existing callers (which pass ``None``) keep working.
        images: Images matching the ``{"type": "image"}`` entries in
            ``history``, in order. An empty list is treated as "no images".
        history: Chat messages passed to ``processor.apply_chat_template``.
        **kwargs: Extra keyword arguments forwarded to ``model.generate``
            (any ``streamer`` entry is overwritten).

    Yields:
        The decoded reply accumulated so far; each yield extends the previous one.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised here after the stream has been closed.
    """
    from threading import Thread

    from transformers import TextIteratorStreamer

    global processor, model

    model.to("cuda")
    if not images:
        images = None

    prompt = processor.apply_chat_template(history, add_generation_prompt=True)
    print("Prompt: ")
    print(prompt)
    print("Images: ")
    print(images)
    inputs = processor(text=prompt, images=images, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    kwargs["streamer"] = streamer
    inputs.update(kwargs)

    # Errors raised inside the worker thread, handed back to this generator.
    errors = []

    def _generate():
        # The streamer is created without a timeout, so iterating it waits
        # until the stream is ended. If generate() fails before ending it,
        # the consumer loop below would block forever; record the error and
        # close the stream explicitly so the loop terminates.
        try:
            model.generate(**inputs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    output = ""
    for chunk in streamer:
        output += chunk
        yield output

    # The stream only ends once generation finished (or failed and was
    # closed above), so this join returns promptly and reaps the thread.
    thread.join()
    if errors:
        raise errors[0]
|
|
|
def enable_next_image(uploaded_images, image):
    """Record a newly uploaded image and lock the input textbox.

    ``image`` is appended to ``uploaded_images`` in place; the same list is
    returned together with a cleared, non-interactive MultimodalTextbox.
    """
    uploaded_images.append(image)
    locked_box = gr.MultimodalTextbox(value=None, interactive=False)
    return uploaded_images, locked_box
|
|
|
def add_message(history, message):
    """Append a MultimodalTextbox submission to the chatbot history.

    Every uploaded file becomes its own ``[(file,), None]`` entry, followed
    by the text (when non-empty) as ``[text, None]``. ``history`` is mutated
    in place.

    Returns:
        The updated history and a cleared MultimodalTextbox.
    """
    history.extend([(path,), None] for path in (message["files"] or []))
    text = message["text"]
    if text:
        history.append([text, None])
    return history, gr.MultimodalTextbox(value=None)
|
|
|
def print_like_dislike(x: gr.LikeData):
    """Print a chatbot like/dislike event as ``index value liked``."""
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)
|
|
|
|
|
def get_chat_images(history):
    """Load every uploaded image referenced in the chat history.

    Image uploads are the history entries whose first element is a tuple;
    the first item of that tuple is passed to ``load_image``.

    Returns:
        The loaded images, in the order they appear in ``history``.
    """
    return [load_image(turn[0][0]) for turn in history if isinstance(turn[0], tuple)]
|
|
|
def get_chat_history(history):
    """Convert Gradio chatbot history into chat-template messages.

    ``history`` is the list of ``[user, assistant]`` pairs built by
    ``add_message``/``bot``:
    - ``user`` is either a text string, which may contain ``<image>``
      placeholders, or a tuple holding an uploaded image.
    - ``assistant`` is the reply text, or a falsy value if there is none yet.

    Each text turn becomes a ``{"role": "user", "content": [...]}`` message:
    - Every ``<image>`` placeholder becomes a ``{"type": "image"}`` entry.
    - The stripped, non-empty text around placeholders becomes
      ``{"type": "text"}`` entries.
    - A turn with no placeholders keeps its text verbatim.

    A truthy reply on the same pair is emitted as a following ``"assistant"``
    message. Image-only entries produce no message of their own.

    Returns:
        ``(messages, images)``, where ``images`` are all uploaded images in
        history order (from ``get_chat_images``).

    Raises:
        gr.Error: if, counting across all turns seen so far, there are more
            ``<image>`` placeholders than uploaded images.
    """
    images = get_chat_images(history)
    messages = []
    # Number of images already claimed by placeholders in earlier turns. The
    # check must be cumulative; a per-turn check would let a later turn
    # silently reuse images an earlier turn already consumed.
    cur_image_idx = 0

    for message in history:
        user_part = message[0]
        if not isinstance(user_part, str):
            # Image uploads (tuples) were already collected by get_chat_images.
            continue

        num_images = user_part.count("<image>")
        print(num_images, cur_image_idx, len(images))
        if num_images + cur_image_idx > len(images):
            # Raised rather than asserted: `assert` is stripped under `python -O`,
            # and gr.Error matches how bot() reports user-facing problems.
            raise gr.Error("Number of images uploaded is less than the number of <image> placeholders in the text. Please upload more images.")

        content = []
        if num_images > 0:
            leading, *after_placeholders = user_part.split("<image>")
            if leading.strip():
                content.append({"type": "text", "text": leading.strip()})
            for segment in after_placeholders:
                content.append({"type": "image"})
                if segment.strip():
                    content.append({"type": "text", "text": segment.strip()})
        else:
            content.append({"type": "text", "text": user_part})
        cur_image_idx += num_images
        messages.append({"role": "user", "content": content})

        if message[1]:
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": message[1]}],
                }
            )

    return messages, images
|
|
|
|
|
def bot(history):
    """Generate the assistant reply for the newest user turn, streaming it into history.

    Gathers every history entry after the last answered turn:
    - Text entries are joined with spaces.
    - Image uploads are counted.

    If the pending text is empty, a ``gr.Error`` is raised. Otherwise the
    number of ``<image>`` placeholders is reconciled with the number of
    pending images, each time with a UI warning:
    - Missing placeholders are prepended.
    - Surplus trailing placeholders are removed.

    The adjusted text is written back into ``history[-1][0]``. The model
    output is then streamed into ``history[-1][1]``.

    Yields:
        ``history`` after each streamed chunk.

    Raises:
        gr.Error: if the pending turn contains no text.
    """
    pending_text = ""
    pending_image_count = 0
    for entry in reversed(history):
        if entry[1]:
            # Reached an answered turn; everything newer is still pending.
            break
        if isinstance(entry[0], str):
            pending_text = entry[0] + " " + pending_text
        elif isinstance(entry[0], tuple):
            pending_image_count += len(entry[0])
    pending_text = pending_text.strip()

    if not pending_text:
        raise gr.Error("Please enter a message")

    placeholder = "<image>"

    missing = pending_image_count - pending_text.count(placeholder)
    if missing > 0:
        gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
        pending_text = "<image> " * missing + pending_text
        history[-1][0] = pending_text

    surplus = pending_text.count(placeholder) - pending_image_count
    if surplus > 0:
        gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
        # rsplit splits at the right-most occurrences, so joining the pieces
        # drops the last `surplus` placeholders.
        pending_text = "".join(pending_text.rsplit(placeholder, surplus))
        history[-1][0] = pending_text

    chat_history, chat_images = get_chat_history(history)

    # Greedy decoding (single beam, no sampling).
    generation_kwargs = dict(max_new_tokens=4096, num_beams=1, do_sample=False)

    for partial in generate_stream(None, chat_images, chat_history, **generation_kwargs):
        history[-1][1] = partial
        time.sleep(0.05)
        yield history
|
|
|
|
|
|
|
def build_demo():
    """Build the Mantis chat UI.

    The layout is:
    - a header and usage notes,
    - a chatbot and an image-only multimodal textbox,
    - Send/Clear buttons,
    - clickable examples and a citation.

    Submitting the textbox and pressing Send both append the message with
    ``add_message``. Only if that succeeds do they stream a reply with ``bot``.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown(""" # Mantis
Mantis is a multimodal conversational AI model that can chat with users about images and text. It's optimized for multi-image reasoning, where interleaved text and images can be used to generate responses.

### [Paper](https://arxiv.org/abs/2405.01483) | [Github](https://github.com/TIGER-AI-Lab/Mantis) | [Models](https://huggingface.co/collections/TIGER-Lab/mantis-6619b0834594c878cdb1d6e4) | [Dataset](https://huggingface.co/datasets/TIGER-Lab/Mantis-Instruct) | [Website](https://tiger-ai-lab.github.io/Mantis/)
""")

        gr.Markdown("""## Chat with Mantis
Mantis supports interleaved text-image input format, where you can simply use the placeholder `<image>` to indicate the position of uploaded images.
The model is optimized for multi-image reasoning, while preserving the ability to chat about text and images in a single conversation.
(The model currently serving is [🤗 TIGER-Lab/Mantis-8B-Idefics2](https://huggingface.co/TIGER-Lab/Mantis-8B-Idefics2))
""")

        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)

        chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
        # Only ask the model for a reply once the message was added successfully.
        chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")

        chatbot.like(print_like_dislike, None, None)

        with gr.Row():
            send_button = gr.Button("Send")
            gr.ClearButton([chatbot, chat_input])

        # Same pipeline as pressing Enter. `.success` (not `.then`) so a failed
        # add_message does not trigger generation, matching the submit path.
        # NOTE(review): this reuses api_name="bot_response" from the submit
        # chain; confirm the resulting API endpoint names are what clients expect.
        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).success(
            bot, chatbot, chatbot, api_name="bot_response"
        )

        gr.Examples(
            examples=[
                {
                    "text": "<image> <image> <image> Which image shows a different mood of character from the others?",
                    "files": ["./examples/image12.jpg", "./examples/image13.jpg", "./examples/image14.jpg"]
                },
                {
                    "text": "<image> <image> What's the difference between these two images? Please describe as much as you can.",
                    "files": ["./examples/image1.jpg", "./examples/image2.jpg"]
                },
                {
                    "text": "<image> <image> Which image shows an older dog?",
                    "files": ["./examples/image8.jpg", "./examples/image9.jpg"]
                },
                {
                    "text": "Write a description for the given image sequence in a single paragraph, what is happening in this episode?",
                    "files": ["./examples/image3.jpg", "./examples/image4.jpg", "./examples/image5.jpg", "./examples/image6.jpg", "./examples/image7.jpg"]
                },
                {
                    "text": "<image> <image> How many dices are there in image 1 and image 2 respectively?",
                    "files": ["./examples/image10.jpg", "./examples/image15.jpg"]
                },
            ],
            inputs=[chat_input],
        )

        gr.Markdown("""
## Citation
```
@article{jiang2024mantis,
title={MANTIS: Interleaved Multi-Image Instruction Tuning},
author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
journal={arXiv preprint arXiv:2405.01483},
year={2024}
}
```""")

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Gradio app and launch it. The app is kept
    # bound to the module-level name `demo`.
    demo = build_demo()
    demo.launch()