# Hugging Face Space: erwannd — "Update app.py" (commit eeaab3b, verified)
import gradio as gr
import spaces
from threading import Thread
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from transformers import TextIteratorStreamer
from PIL import Image
from peft import PeftModel
import requests
import torch, os, re, json
import time
# Hub ids: the base LLaVA-Next checkpoint and the PEFT finetune layered on top.
base_model = "llava-hf/llava-v1.6-mistral-7b-hf"
finetune_repo = "erwannd/llava-v1.6-mistral-7b-finetune-combined4k"
# Processor (tokenizer + image processor) comes from the base model.
processor = LlavaNextProcessor.from_pretrained(base_model)
# Load base weights in fp16; low_cpu_mem_usage avoids materializing a full
# fp32 copy in host RAM during loading.
model = LlavaNextForConditionalGeneration.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
# Attach the finetuned adapter weights, then move the model to the first GPU.
model = PeftModel.from_pretrained(model, finetune_repo)
model.to("cuda:0")
@spaces.GPU
def predict(image, input_text):
    """Stream a LLaVA-Next answer about `image` for the question `input_text`.

    Yields the accumulated generated text (with the echoed prompt stripped)
    so Gradio can render the answer incrementally.
    """
    image = image.convert("RGB")
    # LLaVA-1.6 Mistral chat template; <image> marks where vision tokens go.
    prompt = f"[INST] <image>\n{input_text} [/INST]"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(0, torch.float16)
    streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})

    # Run generation in a background thread so tokens stream as they are
    # produced; calling generate() synchronously would block until the whole
    # answer is finished and defeat the streamer.
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200, do_sample=False)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # The streamer echoes the prompt (skip_prompt defaults to False); strip it
    # by length. The decoded prompt lacks the <image> placeholder, hence the
    # slightly different template string here.
    text_prompt = f"[INST] \n{input_text} [/INST]"
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        generated_text_without_prompt = buffer[len(text_prompt):]
        time.sleep(0.04)  # small throttle so the UI updates smoothly
        yield generated_text_without_prompt
    thread.join()
# --- Gradio UI pieces -----------------------------------------------------
image = gr.components.Image(type="pil")
input_prompt = gr.components.Textbox(label="Input Prompt")
model_output = gr.components.Textbox(label="Model Output")

# Clickable example inputs: chart image path -> question asked about it.
_example_prompts = {
    "./examples/bar_m01.png": "Evaluate and explain if this chart is misleading",
    "./examples/bar_n01.png": "Is this chart misleading? Explain",
    "./examples/fox_news_cropped.png": "Tell me if this chart is misleading",
    "./examples/line_m01.png": "Explain if this chart is misleading",
    "./examples/line_m04.png": "Evaluate and explain if this chart is misleading",
    "./examples/pie_m01.png": "Evaluate if this chart is misleading, if so explain",
    "./examples/pie_m02.png": "Is this chart misleading? Explain",
}
examples = [[path, question] for path, question in _example_prompts.items()]

description_markdown = """Demo for [LlavaNext](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) finetuned on [charts dataset](https://huggingface.co/datasets/chart-misinformation-detection/bar_line_pie_4k)"""
title = "LlavaNext finetuned on Misleading Chart Dataset"
# Wire the components into a Gradio Interface and start the server.
interface = gr.Interface(
    fn=predict,
    inputs=[image, input_prompt],
    outputs=model_output,
    title=title,
    description=description_markdown,
    examples=examples,
    theme='gradio/soft',
    cache_examples=False,  # examples invoke the GPU model; don't pre-run them
)
interface.launch()