import gradio as gr
import spaces
import os
import time
from PIL import Image
import functools
from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava_stream, MLlavaForConditionalGeneration, chat_mllava
from models.conversation import conv_templates
from typing import List

processor = MLlavaProcessor.from_pretrained("remyxai/SpaceMantis")
model = LlavaForConditionalGeneration.from_pretrained("remyxai/SpaceMantis")
conv_template = conv_templates['llama_3']
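
# The Gradio callbacks below wrap chat_mllava / chat_mllava_stream. A minimal
# direct-call sketch (assumes the same signature used in generate() below; the
# image path matches the bundled example under ./examples):
#
#   img = Image.open("./examples/warehouse_rgb.jpg").convert("RGB")
#   reply, _history = chat_mllava(
#       "<image> How tall is the man in the red hat?",  # <image> marks where the image goes
#       [img], model, processor, history=[],
#       max_new_tokens=256, num_beams=1, do_sample=False,
#   )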

@spaces.GPU  # assumed: ZeroGPU Spaces allocate a GPU per call via this decorator (implied by `import spaces` and the Space running on Zero)
def generate_stream(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    global processor, model
    model = model.to("cuda")
    if not images:
        images = None
    for text, history in chat_mllava_stream(text, images, model, processor, history=history, **kwargs):
        yield text
    return text

@spaces.GPU  # assumed: same ZeroGPU decorator as above
def generate(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    # Non-streaming variant of generate_stream; not wired into the UI below.
    global processor, model
    model = model.to("cuda")
    if not images:
        images = None
    generated_text, history = chat_mllava(text, images, model, processor, history=history, **kwargs)
    return generated_text

def enable_next_image(uploaded_images, image):
    # Currently unused by the UI below.
    uploaded_images.append(image)
    return uploaded_images, gr.MultimodalTextbox(value=None, interactive=False)

def add_message(history, message):
    if message["files"]:
        for file in message["files"]:
            history.append([(file,), None])
    if message["text"]:
        history.append([message["text"], None])
    return history, gr.MultimodalTextbox(value=None)
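
# Shapes above: `message` is the MultimodalTextbox payload, a dict like
# {"text": "<image> Describe this", "files": ["/tmp/img.png"]}; the Chatbot
# `history` stores images as one-tuples, e.g. [[("/tmp/img.png",), None],
# ["<image> Describe this", None]], with None as the pending bot reply.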

def print_like_dislike(x: gr.LikeData):
    print(x.index, x.value, x.liked)

def get_chat_history(history):
    chat_history = []
    user_role = conv_template.roles[0]
    assistant_role = conv_template.roles[1]
    for i, message in enumerate(history):
        if isinstance(message[0], str):
            chat_history.append({"role": user_role, "text": message[0]})
            if i != len(history) - 1:
                assert message[1], "The bot message is missing, internal error"
                chat_history.append({"role": assistant_role, "text": message[1]})
            else:
                assert not message[1], "The last bot message should be empty, got: {}".format(message[1])
                chat_history.append({"role": assistant_role, "text": ""})
    return chat_history
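
# Example: with conv_template.roles assumed to be ("user", "assistant") for the
# llama_3 template, a pending turn ["<image> How tall is the man?", None] maps to
#   [{"role": "user", "text": "<image> How tall is the man?"},
#    {"role": "assistant", "text": ""}]
# where the empty assistant turn is the slot the model fills in.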

def get_chat_images(history):
    images = []
    for message in history:
        if isinstance(message[0], tuple):
            images.extend(message[0])
    return images

def bot(history):
    print(history)
    # Collect the trailing block of messages with no bot reply yet: walk
    # backwards and stop at the first message that already has a response.
    cur_messages = {"text": "", "images": []}
    for message in history[::-1]:
        if message[1]:
            break
        if isinstance(message[0], str):
            cur_messages["text"] = message[0] + " " + cur_messages["text"]
        elif isinstance(message[0], tuple):
            cur_messages["images"].extend(message[0])
    cur_messages["text"] = cur_messages["text"].strip()
    cur_messages["images"] = cur_messages["images"][::-1]
    if not cur_messages["text"]:
        raise gr.Error("Please enter a message")
    # Balance <image> placeholders against the number of uploaded images.
    if cur_messages['text'].count("<image>") < len(cur_messages['images']):
        gr.Warning("More images were uploaded than <image> placeholders in the text. Automatically prepending <image> to the text.")
        cur_messages['text'] = "<image> " * (len(cur_messages['images']) - cur_messages['text'].count("<image>")) + cur_messages['text']
        history[-1][0] = cur_messages["text"]
    if cur_messages['text'].count("<image>") > len(cur_messages['images']):
        gr.Warning("Fewer images were uploaded than <image> placeholders in the text. Automatically removing the extra <image> placeholders.")
        # Reverse the string so str.replace's count argument strips the *last* placeholders.
        cur_messages['text'] = cur_messages['text'][::-1].replace("<image>"[::-1], "", cur_messages['text'].count("<image>") - len(cur_messages['images']))[::-1]
        history[-1][0] = cur_messages["text"]
    chat_history = get_chat_history(history)
    chat_images = get_chat_images(history)
    generation_kwargs = {
        "max_new_tokens": 4096,
        "num_beams": 1,
        "do_sample": False
    }
    # text=None: the prompt is already the last user turn in chat_history.
    response = generate_stream(None, chat_images, chat_history, **generation_kwargs)
    for _output in response:
        history[-1][1] = _output
        time.sleep(0.05)
        yield history
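
# bot() is a generator: each partial output from generate_stream() overwrites
# the last bot turn and is yielded, so gr.Chatbot re-renders as tokens arrive.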

def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown("""# SpaceMantis
Mantis is a multimodal conversational AI model fine-tuned from [Mantis-8B-siglip-llama3](https://huggingface.co/TIGER-Lab/Mantis-8B-siglip-llama3) for enhanced spatial reasoning. It is optimized for multi-image reasoning, where interleaved text and images can be used to generate responses.
### [Github](https://github.com/remyxai/VQASynth) | [Model](https://huggingface.co/remyxai/SpaceMantis) | [Dataset](https://huggingface.co/datasets/remyxai/mantis-spacellava)
""")
        gr.Markdown("""## Chat with SpaceMantis
SpaceMantis supports an interleaved text-image input format: simply use the placeholder `<image>` to indicate the position of uploaded images.
The model is optimized for multi-image reasoning while preserving the ability to chat about text and images in a single conversation.
(The model currently serving is [🤗 remyxai/SpaceMantis](https://huggingface.co/remyxai/SpaceMantis))
""")
        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter a message or upload images. Use <image> to indicate the position of uploaded images.", show_label=True)
        chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
""" | |
with gr.Accordion(label='Advanced options', open=False): | |
temperature = gr.Slider( | |
label='Temperature', | |
minimum=0.1, | |
maximum=2.0, | |
step=0.1, | |
value=0.2, | |
interactive=True | |
) | |
top_p = gr.Slider( | |
label='Top-p', | |
minimum=0.05, | |
maximum=1.0, | |
step=0.05, | |
value=1.0, | |
interactive=True | |
) | |
""" | |
        bot_msg = chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")
        chatbot.like(print_like_dislike, None, None)
        with gr.Row():
            send_button = gr.Button("Send")
            clear_button = gr.ClearButton([chatbot, chat_input])
        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).then(
            bot, chatbot, chatbot, api_name="bot_response"
        )
        gr.Examples(
            examples=[
                {
                    "text": "Give me the height of the man in the red hat in feet.",
                    "files": ["./examples/warehouse_rgb.jpg"]
                },
            ],
            inputs=[chat_input],
        )
gr.Markdown(""" | |
## Citation | |
``` | |
@article{chen2024spatialvlm, | |
title = {SpatialVLM: Endowing Vision-Language Models with Spatial Reasoning Capabilities}, | |
author = {Chen, Boyuan and Xu, Zhuo and Kirmani, Sean and Ichter, Brian and Driess, Danny and Florence, Pete and Sadigh, Dorsa and Guibas, Leonidas and Xia, Fei}, | |
journal = {arXiv preprint arXiv:2401.12168}, | |
year = {2024}, | |
url = {https://arxiv.org/abs/2401.12168}, | |
} | |
@article{jiang2024mantis, | |
title={MANTIS: Interleaved Multi-Image Instruction Tuning}, | |
author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu}, | |
journal={arXiv preprint arXiv:2405.01483}, | |
year={2024} | |
} | |
```""") | |
return demo | |
if __name__ == "__main__": | |
demo = build_demo() | |
demo.launch() | |