File size: 6,360 Bytes
f04732f
a3db70a
c36d5bb
58c422e
a3db70a
 
 
 
f04732f
08bc998
a3db70a
 
b6c6d0c
a3db70a
4d18fd2
a3db70a
 
b6c6d0c
 
a3db70a
c36d5bb
 
 
 
 
 
 
 
 
 
 
 
 
 
88b9346
c36d5bb
 
 
 
 
 
 
 
 
 
 
 
88b9346
c36d5bb
 
 
 
 
 
a3db70a
 
 
b6c6d0c
 
 
 
 
 
 
 
 
 
 
 
 
a3db70a
b6c6d0c
a3db70a
b6c6d0c
 
 
 
 
 
 
 
 
 
 
 
a3db70a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c6d0c
a3db70a
b6c6d0c
 
a3db70a
 
 
 
 
b699eb0
b6c6d0c
 
c36d5bb
 
 
d227c5a
a3db70a
b6c6d0c
b699eb0
b6c6d0c
 
 
 
 
 
 
 
 
a3db70a
 
4be8019
b6c6d0c
a3db70a
 
b6c6d0c
 
 
 
 
a3db70a
b6c6d0c
 
 
 
 
 
 
 
 
 
 
a3db70a
6e20834
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria
from modeling_llava_qwen2 import LlavaQwen2ForCausalLM
from threading import Thread
import re
import time 
from PIL import Image
import spaces
import subprocess
import os
import sys

# Install flash-attn at startup with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE so its setup
# does not try to compile CUDA kernels (presumably it falls back to a prebuilt wheel --
# TODO confirm against flash-attn's setup.py). Notes on the invocation:
#   * `env` is merged into os.environ rather than replacing it; passing only the one
#     variable stripped PATH, HOME, CUDA/virtualenv settings, etc. from the child.
#   * `sys.executable -m pip` installs into the interpreter running this app instead of
#     whatever `pip` a shell happens to resolve, and needs no shell=True string.
#   * The return code is deliberately not checked, as before; a failed install would
#     presumably surface later when the model is loaded with
#     attn_implementation="flash_attention_2".
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

# Initialize tokenizer (doesn't require CUDA)
tokenizer = AutoTokenizer.from_pretrained(
    'qnguyen3/nanoLLaVA-1.5',
    trust_remote_code=True)

# The model is intentionally not loaded at import time: bot_streaming loads it lazily
# on its first call, inside the GPU-decorated function.
model = None

class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that halts generation once the output ends in a stop keyword.

    Each keyword is checked two ways: the tail of the generated token ids is compared
    with the keyword's own token ids, and, as a fallback, the decoded text of the last
    few newly generated tokens is searched for the keyword string. For a batch,
    generation stops only when every row matches.

    Args:
        keywords: Stop strings, e.g. ``['<|im_end|>']``.
        tokenizer: Tokenizer used to encode the keywords and decode outputs.
        input_ids: Prompt ids; only ``input_ids.shape[1]`` (the prompt length) is kept.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS the tokenizer may prepend, so the ids can match a generated tail.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the single row ``output_ids`` (shape (1, seq)) ends with a keyword."""
        # Move keyword ids to the output's device only when it differs, not on every step.
        if self.keyword_ids and self.keyword_ids[0].device != output_ids.device:
            self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True
        # Decode only tokens generated after the prompt, capped at the longest keyword.
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        if offset <= 0:
            # With offset == 0, `output_ids[:, -0:]` would decode the whole sequence
            # (prompt included); with offset < 0 it would decode an arbitrary suffix.
            # Either way there is no generated window to search.
            return False
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Stop only when every row of the batch has produced a stop keyword."""
        return all(
            self.call_for_batch(row.unsqueeze(0), scores)
            for row in output_ids
        )


def _upload_path(entry):
    """Return the filesystem path for one entry of ``message["files"]``.

    Accepts either a dict with a ``"path"`` key (the shape this app originally indexed)
    or a bare path string.
    """
    return entry["path"] if isinstance(entry, dict) else entry


def _find_image(message, history):
    """Return the image path to condition on, or ``None`` if no image was uploaded.

    Prefers the last file attached to the current message; otherwise falls back to the
    first file turn in the tuple-format history (a turn whose user side is a tuple of paths).
    """
    if message["files"]:
        return _upload_path(message["files"][-1])
    for turn in history:
        if isinstance(turn[0], tuple):
            return turn[0][0]
    return None


def _build_messages(message, history):
    """Convert tuple-format chat history plus the new message into chat-template messages.

    File turns are skipped (the image is fed to the model separately), ``None`` assistant
    replies are omitted, and the first user turn is prefixed with the single ``<image>``
    placeholder the prompt must contain. The result always starts with a user message.
    """
    messages = []
    for human, assistant in history:
        if isinstance(human, tuple):
            continue
        messages.append({"role": "user", "content": human})
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message['text']})
    messages[0]["content"] = f"<image>\n{messages[0]['content']}"
    return messages


@spaces.GPU
def bot_streaming(message, history):
    """Stream a nanoLLaVA reply about the uploaded image for Gradio's ChatInterface.

    Args:
        message: Multimodal message dict with ``"text"`` and ``"files"`` keys.
        history: Tuple-format chat history: ``[user, assistant]`` pairs, where a file
            upload appears as a turn whose user side is a tuple of paths.

    Yields:
        The accumulated response text after each streamed chunk, or a single error
        message when no image has been uploaded.
    """
    global model

    # Load the model lazily, inside the GPU-decorated call, rather than at import time.
    if model is None:
        model = LlavaQwen2ForCausalLM.from_pretrained(
            'qnguyen3/nanoLLaVA-1.5',
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
            device_map="auto")

    image_path = _find_image(message, history)
    if image_path is None:
        # This function is a generator: `return "..."` would end the stream without the
        # text ever reaching the chat window, so the message has to be yielded.
        yield "Please upload an image for LLaVA to work."
        return

    messages = _build_messages(message, history)
    image = Image.open(image_path).convert("RGB")

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True)
    # Tokenize around the placeholder and splice in -200 where `<image>` was. partition()
    # splits on the first occurrence only, so a literal "<image>" typed later in the chat
    # stays in the prompt instead of silently truncating it.
    before, _, after = text.partition('<image>')
    # NOTE(review): inputs are moved to the model's device explicitly; harmless if the
    # custom generate() would also have done so.
    input_ids = torch.tensor(
        tokenizer(before).input_ids + [-200] + tokenizer(after).input_ids,
        dtype=torch.long,
    ).unsqueeze(0).to(model.device)

    stopping_criteria = KeywordsStoppingCriteria(['<|im_end|>'], tokenizer, input_ids)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    image_tensor = model.process_images([image], model.config).to(
        device=model.device, dtype=model.dtype)
    generation_kwargs = dict(
        input_ids=input_ids,
        images=image_tensor,
        streamer=streamer,
        max_new_tokens=512,
        stopping_criteria=[stopping_criteria],
        # NOTE(review): temperature presumably only matters if sampling is enabled
        # (e.g. by the model's generation config) -- TODO confirm.
        temperature=0.01,
    )

    # generate() blocks until done, so run it on a worker thread and consume its
    # output here through the streamer.
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    response = ""
    for new_text in streamer:
        response += new_text
        time.sleep(0.04)  # small throttle between UI updates
        yield response


# Example prompts offered in the UI; each pairs a question with a bundled image file.
_EXAMPLES = [
    {"text": "Who is this guy?", "files": ["./demo_1.jpg"]},
    {"text": "What does the text say?", "files": ["./demo_2.jpeg"]},
]

# Markdown description passed to the chat interface.
_DESCRIPTION = (
    "Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA-1.5) in this demo. "
    "Built on top of [Quyen-SE-v0.1](https://huggingface.co/vilm/Quyen-SE-v0.1) (Qwen1.5-0.5B) "
    "and [Google SigLIP-400M](https://huggingface.co/google/siglip-so400m-patch14-384). "
    "Upload an image and start chatting about it, or simply try one of the examples below. "
    "If you don't upload an image, you will receive an error."
)

# Multimodal chat UI wired to the streaming generator above.
demo = gr.ChatInterface(
    fn=bot_streaming,
    multimodal=True,
    title="🚀nanoLLaVA-1.5",
    description=_DESCRIPTION,
    examples=_EXAMPLES,
    stop_btn="Stop Generation",
)

# Enable Gradio's request queue, then start the server.
demo.queue()
demo.launch()