File size: 3,452 Bytes
ea37c27
 
ca317b2
 
ea37c27
ca317b2
ea37c27
8c54553
ed5a7bf
ee668ff
a7191f1
 
 
 
 
 
 
 
 
ca317b2
a7191f1
ca317b2
a7191f1
ca317b2
 
a7191f1
 
 
ee668ff
386e329
ea37c27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec0b15
 
 
 
 
 
1d234d5
cec0b15
 
 
 
 
 
 
 
 
 
 
 
 
ea37c27
 
 
cec0b15
 
 
 
 
 
 
 
 
 
ea37c27
 
 
 
 
 
 
 
 
f07fb5d
ea37c27
 
 
 
 
 
 
ee668ff
 
 
ea37c27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import time
from threading import Thread
from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
import gradio as gr
import torch
from PIL import Image
import argparse
import spaces

# Command-line configuration, parsed once at import time. ``--model-path``,
# ``--device`` and the quantisation flags feed ``load_pretrained_model`` below;
# the whole ``args`` namespace is also forwarded to ``chat_llava`` on every turn.
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
# argparse applies ``type`` only to string defaults, so the default must already
# be a float to keep ``args.temperature`` consistently typed.
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# Build the tokenizer, model, image processor and context length once at import
# time; every chat request below reuses these module-level objects.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,            # NOTE(review): presumably the model_base slot (no separate base) — confirm
    'llava_llama3',  # NOTE(review): presumably the model-name/architecture selector — confirm
    args.load_8bit,
    args.load_4bit,
    device=args.device,
)

def _latest_image_path(message, history):
    """Return the image to condition this turn on, or ``None`` if there is none.

    The last file attached to the current ``message`` wins. Otherwise the most
    recent history turn whose user part is a tuple is used, taking the tuple's
    first element as the path.

    NOTE(review): this assumes Gradio's tuple-style history, where an uploaded
    file shows up as ``(path, ...)``. Confirm this against the Gradio version in use.
    """
    files = message["files"]
    if files:
        latest = files[-1]
        # A file entry is either a plain path or a dict carrying it under "path".
        return latest["path"] if isinstance(latest, dict) else latest
    # Walk backwards so we stop at the newest image instead of scanning everything.
    for turn in reversed(history):
        user_part = turn[0]
        if isinstance(user_part, tuple):
            return user_part[0]
    return None


@spaces.GPU
def bot_streaming(message, history):
    """Stream FinLLaVA's reply to one multimodal chat message.

    Args:
        message: Dict from the multimodal textbox. ``message["text"]`` is the
            prompt and ``message["files"]`` is the list of attached files.
        history: Earlier chat turns. An image uploaded in an earlier turn is
            reused when the current message has none.

    Yields:
        The reply text accumulated so far, once per chunk produced by
        ``chat_llava``.

    Raises:
        gr.Error: If no image is available, or if generation fails in the
            worker thread.
    """
    import queue  # stdlib; local so this block needs no file-level import change

    print(message)

    image_path = _latest_image_path(message, history)
    if image_path is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    image = Image.open(image_path)
    prompt = message['text']

    # Worker -> generator hand-off. A blocking queue replaces list.pop(0) plus
    # sleep-polling, so chunks are forwarded as soon as they arrive.
    chunks = queue.Queue()
    finished = object()  # sentinel: the worker has stopped (normally or not)
    errors = []

    def generate_output():
        try:
            output = chat_llava(
                args=args,
                image_file=image,
                text=prompt,
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
            )
            for new_text in output:
                chunks.put(new_text)
        except Exception as exc:  # re-raised in the caller's thread below
            errors.append(exc)
        finally:
            # Always unblock the consumer, even when generation blew up.
            chunks.put(finished)

    # Daemon so an abandoned generation (e.g. the user hit Stop) cannot hold up
    # interpreter shutdown.
    worker = Thread(target=generate_output, daemon=True)
    worker.start()

    buffer = ""
    while True:
        new_text = chunks.get()
        if new_text is finished:
            break
        buffer += new_text
        yield buffer

    worker.join()
    if errors:
        # Previously a failing chat_llava ended the stream silently; surface it.
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]


# Sample prompts shown beneath the chat box. Each pairs a prompt with an image
# path relative to the working directory.
_EXAMPLES = [
    {"text": "What is on the flower?", "files": ["./bee.jpg"]},
    {"text": "How to make this pastry?", "files": ["./baklava.png"]},
]

chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    # Multimodal chat wired to the streaming generator, rendered with the
    # pre-built chatbot and textbox components above.
    gr.ChatInterface(
        fn=bot_streaming,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        title="FinLLaVA",
        examples=_EXAMPLES,
        stop_btn="Stop Generation",
    )

# Enable request queuing with the REST API closed, then serve the app locally
# with the API page hidden and no public share link.
demo.queue(api_open=False).launch(show_api=False, share=False)