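"""Gradio Space: LLaVA-Next (Mistral-7B) with a LoRA adapter finetuned to
judge whether a chart is misleading and explain why."""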
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from peft import PeftModel
from transformers import (
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    TextIteratorStreamer,
)


# Base LLaVA-Next checkpoint and the LoRA adapter finetuned for
# misleading-chart detection.
base_model = "llava-hf/llava-v1.6-mistral-7b-hf"
finetune_repo = "erwannd/llava-v1.6-mistral-7b-finetune-combined4k"

processor = LlavaNextProcessor.from_pretrained(base_model)

# Load the base model in fp16, attach the adapter weights, and move to GPU.
model = LlavaNextForConditionalGeneration.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(model, finetune_repo)
model.to("cuda:0")


# Streamed inference: generate() runs on a background thread and pushes
# decoded tokens into a TextIteratorStreamer, which this generator drains.
@spaces.GPU
def predict(image, input_text):
    image = image.convert("RGB")
    prompt = f"[INST] <image>\n{input_text} [/INST]"

    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0", torch.float16)

    # skip_prompt=True makes the streamer emit only newly generated tokens,
    # so the prompt never has to be sliced off the decoded text.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200, do_sample=False)

    # Run generation on a worker thread so tokens can be yielded while the
    # model is still decoding; a blocking generate() call would only let the
    # loop below start after the full answer was finished.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.04)  # small delay to smooth the visible streaming
        yield buffer
    thread.join()
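# Gradio streams each value yielded by predict() straight to the output
# Textbox, so the answer appears incrementally in the UI.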

   
image = gr.Image(type="pil")
input_prompt = gr.Textbox(label="Input Prompt")
model_output = gr.Textbox(label="Model Output")
examples = [["./examples/bar_m01.png", "Evaluate and explain if this chart is misleading"],
            ["./examples/bar_n01.png", "Is this chart misleading? Explain"],
            ["./examples/fox_news_cropped.png", "Tell me if this chart is misleading"],
            ["./examples/line_m01.png", "Explain if this chart is misleading"],
            ["./examples/line_m04.png", "Evaluate and explain if this chart is misleading"],
            ["./examples/pie_m01.png", "Evaluate if this chart is misleading, if so explain"],
            ["./examples/pie_m02.png", "Is this chart misleading? Explain"]]


description_markdown = """Demo of [LlavaNext](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) finetuned on a [misleading-chart dataset](https://huggingface.co/datasets/chart-misinformation-detection/bar_line_pie_4k)."""

title = "LlavaNext finetuned on Misleading Chart Dataset"
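# Standard Interface wiring; cache_examples=False keeps example inference
# live on the GPU instead of precomputing outputs when the Space starts.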
interface = gr.Interface(
    fn=predict, 
    inputs=[image, input_prompt], 
    outputs=model_output, 
    examples=examples, 
    title=title,
    theme='gradio/soft',
    cache_examples=False,
    description=description_markdown
)

interface.launch()