File size: 3,664 Bytes
afff347
ea37c27
afff347
ea37c27
ca317b2
d5bf1ae
 
 
 
 
afff347
 
d5bf1ae
 
 
 
afff347
 
d5bf1ae
 
 
ca30e4f
d5bf1ae
 
 
 
ca30e4f
 
afff347
 
ee668ff
d5bf1ae
 
 
 
 
 
 
 
 
ca317b2
d5bf1ae
ca317b2
d5bf1ae
ca317b2
 
d5bf1ae
 
 
ee668ff
d5bf1ae
 
 
afff347
ea37c27
d5bf1ae
afff347
ea37c27
afff347
ea37c27
 
d5bf1ae
afff347
 
 
 
 
ea37c27
afff347
 
 
d5bf1ae
 
 
 
 
 
 
 
 
 
afff347
 
cec0b15
d5bf1ae
afff347
ea37c27
d5bf1ae
5b853cd
ea37c27
d5bf1ae
5b853cd
d5bf1ae
ea37c27
d5bf1ae
 
 
 
7dc477a
d5bf1ae
afff347
5693cbb
afff347
 
d5bf1ae
 
 
 
ee668ff
 
d5bf1ae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import time
from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer

import spaces
import argparse

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava_llama3.conversation import conv_templates, SeparatorStyle
from llava_llama3.utils import disable_torch_init
from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava_llama3.serve.cli import chat_llava

import requests
from io import BytesIO
import base64
import os
import glob
import pandas as pd
from tqdm import tqdm
import json

# Directory that holds this script. It is echoed in green (ANSI colour 92) and
# exported as GRADIO_TEMP_DIR so Gradio is pointed at it for temporary files.
_this_file = os.path.abspath(__file__)
root_path = os.path.dirname(_this_file)
os.environ['GRADIO_TEMP_DIR'] = root_path
print(f'\033[92m{root_path}\033[0m')

def _build_arg_parser():
    """Create the command-line parser for the demo's model and generation options."""
    cli = argparse.ArgumentParser()

    # Options that take a value: (flag, value type, default).
    valued_options = (
        ("--model-path", str, "TheFinAI/FinLLaVA"),
        ("--device", str, "cuda"),
        ("--conv-mode", str, "llama_3"),
        ("--temperature", float, 0.7),
        ("--max-new-tokens", int, 512),
    )
    for flag, value_type, default_value in valued_options:
        cli.add_argument(flag, type=value_type, default=default_value)

    # Boolean switches: False unless present on the command line.
    for flag in ("--load-8bit", "--load-4bit"):
        cli.add_argument(flag, action="store_true")

    return cli


# Parsed once at import time; `args` is read by the model loader and by the
# chat callback further down in this file.
parser = _build_arg_parser()
args = parser.parse_args()

# Load the checkpoint once at import time. The loader hands back four objects
# (tokenizer, model, image processor, context length) that the chat callback
# in this file reuses for every request.
_loaded_components = load_pretrained_model(
    args.model_path,
    None,
    'llava_llama3',
    args.load_8bit,
    args.load_4bit,
    device=args.device,
)
tokenizer, llava_model, image_processor, context_len = _loaded_components

@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: Mapping with a ``"text"`` entry (the prompt) and a ``"files"``
            entry (a possibly empty list of uploads). Each upload is either a
            dict carrying a ``"path"`` key or the path value itself.
        history: Earlier chat turns. A turn whose first element is a tuple is
            treated as an image upload, and the tuple's first item is used as
            the image file.

    Yields:
        The reply text accumulated so far; it grows as the streamer delivers
        new pieces.

    Raises:
        gr.Error: If neither the message nor the history provides an image.
    """
    print(message)

    image_file = None
    files = message["files"]
    if files:
        # Only the most recently attached file is used.
        newest = files[-1]
        image_file = newest["path"] if isinstance(newest, dict) else newest
    else:
        # No new upload: fall back to the history. There is deliberately no
        # `break`, so the latest image found in the conversation wins.
        for turn in history:
            if isinstance(turn[0], tuple):
                image_file = turn[0][0]

    if image_file is None:
        # gr.Error only reaches the user when it is raised; merely
        # constructing it (as the code did before) shows nothing.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        """Run the chat call in the worker thread, feeding `streamer`."""
        print('\033[92mRunning chat\033[0m')
        try:
            chat_llava(
                args=args,
                image_file=image_file,
                text=message['text'],
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
                streamer=streamer)
        except Exception:
            # Without an end-of-stream signal the consumer loop below would
            # block forever; close the stream, then let the error surface.
            streamer.end()
            raise

    # Generation runs in the background so text can be yielded as it arrives.
    thread = Thread(target=generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Pace the updates pushed to the UI.
        time.sleep(0.06)
        yield buffer

# The chat widgets are built first so they can be handed to the ChatInterface.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

# Sample prompt offered in the UI: a question paired with a remote image URL.
_demo_examples = [
    {
        "text": "What is in this picture?",
        "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"],
    },
]

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        chatbot=chatbot,
        textbox=chat_input,
        multimodal=True,
        title="FinLLaVA Demo",
        description="",
        examples=_demo_examples,
        stop_btn="Stop Generation",
    )

# Turn on the request queue (api_open=False), then start the app with
# show_api=False and without a share link.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)