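"""Visual-RAG Gradio demo (a minimal sketch of the pipeline in this file).

Uploaded files are embedded into a per-session Voyager index; at query time the
top-matching page image is retrieved and Qwen2-VL-2B-Instruct answers from it.
"""
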
import os
import subprocess
import uuid

import gradio as gr
import spaces
import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from voyager_index import Voyager

# Runtime flash-attn install (a common Hugging Face Spaces workaround);
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE lets the wheel install without a local
# CUDA toolkit, and merging os.environ keeps pip on PATH.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
    check=False,  # a failed install should not take the whole demo down
)

device = "cuda" if torch.cuda.is_available() else "cpu"
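# CPU fallback works, but generation will be noticeably slow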


# Initialize the model and processor
model = (
    Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
    )
    .to(device)
    .eval()
)

processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True
)


def create_index(session_id):
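    """Create a per-session Voyager index; embedding_size must match the retriever's vector dimension."""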
    return Voyager(embedding_size=1536, override=True, index_name=session_id)


def add_to_index(files, index):
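    """Embed the uploaded files and add them to this session's index."""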
    index.add_documents([file.name for file in files], batch_size=1)
    return f"Added {len(files)} files to the index."


@spaces.GPU
def query_index(query, index):
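    """Retrieve the best-matching page image for the query, then answer with Qwen2-VL."""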
    res = index(query, k=1)
    retrieved_image = res["documents"][0][0]["image"]

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": retrieved_image,
                },
                {"type": "text", "text": query},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Collect the image/video objects referenced in the chat messages
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    generated_ids = model.generate(**inputs, max_new_tokens=200)
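    # Drop the prompt tokens so only the newly generated answer is decoded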
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return output_text[0], retrieved_image


# Define the Gradio interface
with gr.Blocks() as demo:
    session_id = gr.State(lambda: str(uuid.uuid4()))
    # A callable default is re-evaluated on each page load, so every session
    # gets its own index. Reading session_id.value here would only capture the
    # build-time value, hence the independent UUID.
    index = gr.State(lambda: create_index(str(uuid.uuid4())))

    gr.Markdown("# Full vision pipeline demo")

    with gr.Tab("Add to Index"):
        file_input = gr.File(file_count="multiple", label="Upload Files")
        add_button = gr.Button("Add to Index")
        add_output = gr.Textbox(label="Result")

        add_button.click(add_to_index, inputs=[file_input, index], outputs=add_output)

    with gr.Tab("Query Index"):
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Submit Query")
        with gr.Row():
            query_output = gr.Textbox(label="Answer")
            image_output = gr.Image(label="Retrieved Image")

        query_button.click(
            query_index,
            inputs=[query_input, index],
            outputs=[query_output, image_output],
        )

# Launch the interface
demo.launch()
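
# Run locally with `python app.py`; on Hugging Face Spaces, the @spaces.GPU
# decorator requests a ZeroGPU slice for each query call.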