from argparse import ArgumentParser
import copy

import gradio as gr
from gradio.themes.utils import colors, fonts, sizes

from utils.easydict import EasyDict
from tasks.eval.model_utils import load_pllava
from tasks.eval.eval_utils import (
    ChatPllava,
    conv_plain_v1,
    Conversation,
    conv_templates,
)
from tasks.eval.demo import pllava_theme
SYSTEM = """You are Pllava, a large vision-language assistant.
You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
Follow the instructions carefully and explain your answers in detail based on the provided video.
"""

INIT_CONVERSATION: Conversation = conv_plain_v1.copy()
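# INIT_CONVERSATION starts from the plain conversation template; it is replaced with
# conv_templates[args.conv_mode] once the CLI arguments are parsed (see the bottom of this file).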


# ========================================
# Model Initialization
# ========================================
def init_model(args):
    print('Initializing PLLaVA')
    model, processor = load_pllava(
        args.pretrained_model_name_or_path, args.num_frames,
        use_lora=args.use_lora,
        weight_dir=args.weight_dir,
        lora_alpha=args.lora_alpha,
        use_multi_gpus=args.use_multi_gpus)
    if not args.use_multi_gpus:
        model = model.to('cuda')
    chat = ChatPllava(model, processor)
    return chat
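

# The returned `chat` object wraps the model and processor and provides the
# upload_video / upload_img / ask / answer methods used by the Gradio callbacks below.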


# ========================================
# Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state = INIT_CONVERSATION.copy()
    if img_list is not None:
        img_list = []
    return (
        None,
        gr.update(value=None, interactive=True),
        gr.update(value=None, interactive=True),
        gr.update(placeholder='Please upload your video first', interactive=False),
        gr.update(value="Upload & Start Chat", interactive=True),
        chat_state,
        img_list
    )
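

# upload_img handles either an image or a video upload; its return values map onto the
# outputs wired to upload_button.click below:
# (up_image, up_video, text_input, upload_button, chat_state, img_list).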
def upload_img(gr_img, gr_video, chat_state=None, num_segments=None, img_list=None):
    print(gr_img, gr_video)
    chat_state = INIT_CONVERSATION.copy() if chat_state is None else chat_state
    img_list = [] if img_list is None else img_list

    if gr_img is None and gr_video is None:
        # Nothing was uploaded: keep the inputs interactive and prompt in the textbox.
        return None, None, gr.update(interactive=True, placeholder='Please upload video/image first!'), gr.update(interactive=True), chat_state, None
    if gr_video:
        llm_message, img_list, chat_state = chat.upload_video(gr_video, chat_state, img_list, num_segments)
        return (
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True, placeholder='Type and press Enter'),
            gr.update(value="Start Chatting", interactive=False),
            chat_state,
            img_list,
        )
    if gr_img:
        llm_message, img_list, chat_state = chat.upload_img(gr_img, chat_state, img_list)
        return (
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True, placeholder='Type and press Enter'),
            gr.update(value="Start Chatting", interactive=False),
            chat_state,
            img_list
        )


def gradio_ask(user_message, chatbot, chat_state, system):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat_state = chat.ask(user_message, chat_state, system)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state


def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    llm_message, llm_message_token, chat_state = chat.answer(
        conv=chat_state, img_list=img_list, max_new_tokens=200,
        num_beams=num_beams, temperature=temperature)
    llm_message = llm_message.replace("<s>", "")  # handle <s>
    chatbot[-1][1] = llm_message
    print(chat_state)
    print(f"Answer: {llm_message}")
    return chatbot, chat_state, img_list


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        default='llava-hf/llava-1.5-7b-hf',
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        required=True,
        default=4,
    )
    parser.add_argument(
        "--use_lora",
        action='store_true',
    )
    parser.add_argument(
        "--use_multi_gpus",
        action='store_true',
    )
    parser.add_argument(
        "--weight_dir",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--conv_mode",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--server_port",
        type=int,
        required=False,
        default=7868,
    )
    args = parser.parse_args()
    return args
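

# Illustrative launch command; the script name, weight directory, LoRA alpha and conversation
# template name below are placeholders, not values taken from this file; adjust to your setup:
#   python pllava_demo.py \
#       --pretrained_model_name_or_path llava-hf/llava-1.5-7b-hf \
#       --num_frames 4 \
#       --use_lora --weight_dir <path-to-pllava-weights> --lora_alpha <alpha> \
#       --conv_mode <name-registered-in-conv_templates> \
#       --server_port 7868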


title = """<h1 align="center"><a href="https://github.com/magic-research/PLLaVA"><img src="https://raw.githubusercontent.com/magic-research/PLLaVA/main/assert/logo.png" alt="PLLAVA" border="0" style="margin: 0 auto; height: 100px;" /></a> </h1>"""
description = (
    """<br><p><a href='https://github.com/magic-research/PLLaVA'>
# PLLAVA!
<img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>
- Upload A Video
- Press Upload
- Start Chatting
"""
)

args = parse_args()
model_description = f"""
# MODEL INFO
- pretrained_model_name_or_path:{args.pretrained_model_name_or_path}
- use_lora:{args.use_lora}
- weight_dir:{args.weight_dir}
"""

# with gr.Blocks(title="InternVideo-VideoChat!",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
with gr.Blocks(title="PLLaVA",
               theme=pllava_theme,
               css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(model_description)
    with gr.Row():
        with gr.Column(scale=0.5, visible=True) as video_upload:
            # with gr.Column(elem_id="image", scale=0.5) as img_part:
            with gr.Tab("Video", elem_id='video_tab'):
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360)
            with gr.Tab("Image", elem_id='image_tab'):
                up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload", height=360)
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")

            # num_segments = gr.Slider(
            #     minimum=8,
            #     maximum=64,
            #     value=8,
            #     step=1,
            #     interactive=True,
            #     label="Video Segments",
            # )
        with gr.Column(visible=True) as input_raws:
            system_string = gr.Textbox(SYSTEM, interactive=True, label='system')
            num_beams = gr.Slider(
                minimum=1,
                maximum=5,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(elem_id="chatbot", label='Conversation')
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.15, min_width=0):
                    # This rebinds the name `clear`, so only this "🔄Clear" button is wired to
                    # gradio_reset below; the "Restart" button above has no event handler.
                    clear = gr.Button("🔄Clear")
    with gr.Row():
        examples = gr.Examples(
            examples=[
                ['example/jesse_dance.mp4', 'What is the man doing?'],
                ['example/yoga.mp4', 'What is the woman doing?'],
                ['example/cooking.mp4', 'Describe the background, characters and the actions in the provided video.'],
                # ['example/cooking.mp4', 'What is happening in the video?'],
                ['example/working.mp4', 'Describe the background, characters and the actions in the provided video.'],
                ['example/1917.mp4', 'Describe the background, characters and the actions in the provided video.'],
            ],
            inputs=[up_video, text_input],
            cache_examples=False,
        )

    chat = init_model(args)
    INIT_CONVERSATION = conv_templates[args.conv_mode]
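    # conv_templates[args.conv_mode] raises a KeyError if --conv_mode is left at its default
    # (None) or names an unregistered template, so pass a valid template name on the command line.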

    upload_button.click(upload_img, [up_image, up_video, chat_state], [up_image, up_video, text_input, upload_button, chat_state, img_list])
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    run.click(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    run.click(lambda: "", None, text_input)
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, up_image, up_video, text_input, upload_button, chat_state, img_list], queue=False)

demo.queue(max_size=5)
demo.launch(server_port=args.server_port)
# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)