import os

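# Install the local LLaVA package in editable mode, then replace the
# pip-resolved bitsandbytes with the prebuilt CUDA wheel shipped alongside
# this script.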
os.system('pip install -q -e .')
os.system('pip uninstall -y bitsandbytes')
os.system('pip install bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl')

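# Reduce CUDA memory fragmentation from repeated allocations; must be set
# before torch initializes CUDA.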
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
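# Sanity checks: CUDA should be visible and bitsandbytes should locate its
# GPU binary (python -m bitsandbytes prints a diagnostic report).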
print(torch.cuda.is_available())
print(os.system('python -m bitsandbytes'))

import io
from contextlib import redirect_stdout

import gradio as gr
from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaMistralForCausalLM
from llava.eval.run_llava import eval_model

# LLaVA-Med model setup
model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
kwargs = {"device_map": "auto"}
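# 4-bit NF4 quantization with double quantization and fp16 compute keeps the
# 7B model within a single GPU's memory. load_in_4bit is set only inside the
# config: passing it as a separate kwarg as well makes recent transformers
# versions raise an error.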
kwargs['quantization_config'] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4'
)
model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
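# use_fast=False selects the slow SentencePiece tokenizer, matching how the
# upstream LLaVA code loads it.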
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

with gr.Blocks(theme=gr.themes.Monochrome()) as app:
    with gr.Column(scale=1):
        gr.Markdown("<center><h1>LLaVa-Med</h1></center>")

        with gr.Row():
            image = gr.Image(type="filepath", scale=2)
            question = gr.Textbox(placeholder="Enter a question", label="Question", scale=3)

        with gr.Row():
            answer = gr.Textbox(placeholder="Answer pops up here", label="Answer", scale=1)

        def run_inference(image, question):
            # Build the argument namespace eval_model expects; temperature 0
            # selects deterministic greedy decoding.
            args = type('Args', (), {
                "model_path": model_path,
                "model_base": None,
                "image_file": image,
                "query": question,
                "conv_mode": None,
                "sep": ",",
                "temperature": 0,
                "top_p": None,
                "num_beams": 1,
                "max_new_tokens": 512
            })()

            # eval_model prints the generation to stdout rather than returning
            # it, so capture stdout to recover the answer text
            f = io.StringIO()
            with redirect_stdout(f):
                eval_model(args)
            llava_med_result = f.getvalue()
            print(llava_med_result)

            return llava_med_result
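        # Hypothetical example: run_inference("scan.png", "Any abnormality?")
        # would return whatever eval_model printed, i.e. the generated answer.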

        with gr.Row():
            btn = gr.Button("Run Inference", scale=1)

        btn.click(fn=run_inference, inputs=[image, question], outputs=answer)

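# debug=True blocks the script and surfaces errors in the console; height and
# width only affect the inline iframe when launched from a notebook.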
app.launch(debug=True, height=800, width="100%")