File size: 4,138 Bytes
349b5c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
# To force CPU execution (e.g. while debugging), override with device = "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pixel budget passed to the processor below, expressed in multiples of
# 28 * 28. Presumably each 28x28 block becomes one visual token in Qwen2-VL,
# so max_pixels caps per-image cost — TODO confirm against the processor docs.
# A larger budget such as 2560 * 28 * 28 trades speed/memory for more detail.
min_pixels = 1 * 28 * 28
max_pixels = 256 * 28 * 28


processor = AutoProcessor.from_pretrained(
    "MrLight/dse-qwen2-2b-mrl-v1", min_pixels=min_pixels, max_pixels=max_pixels
)
# Load the DSE checkpoint once at import time, move it to `device` and switch
# it to eval mode; encode_queries/encode_images reuse this instance.
model = (
    Qwen2VLForConditionalGeneration.from_pretrained(
        "MrLight/dse-qwen2-2b-mrl-v1",
        # FlashAttention 2 is requested only on CUDA. The flash_attn package
        # is painful to install on some hosts (e.g. Spaces), so CPU runs use
        # the plain "eager" attention implementation instead.
        attn_implementation="flash_attention_2"
        if device == "cuda"
        else "eager",
        # bfloat16 weights on GPU; full float32 on CPU.
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    )
    .to(device)
    .eval()
)
# Left-pad batches so that position -1 of every row is that sequence's final
# token (the appended "<|endoftext|>"), which is what get_embedding() pools.
processor.tokenizer.padding_side = "left"
# NOTE(review): nothing visible in this file reads model.padding_side; it looks
# like a no-op kept from the reference snippet — confirm before removing.
model.padding_side = "left"


def get_embedding(last_hidden_state: torch.Tensor, dimension: int):
    """Pool, truncate and L2-normalise the final-token hidden state.

    Args:
        last_hidden_state: Last-layer hidden states indexed as
            ``[batch, position, hidden]``. Batches are expected to be
            left-padded so that position ``-1`` holds each sequence's final
            token.
        dimension: Number of leading hidden dimensions to keep before
            normalising (e.g. 1536, or 512 for a cheaper embedding).

    Returns:
        A float32 ``numpy.ndarray`` of shape
        ``(batch, min(dimension, hidden))`` whose rows have unit L2 norm
        (an all-zero row stays all-zero).
    """
    # Last position of every sequence; with left padding this is the real
    # final token rather than a pad token.
    reps = last_hidden_state[:, -1]
    # Truncate, drop any autograd history (so .numpy() cannot fail), and
    # upcast BEFORE normalising: normalising in bfloat16 rounds the result to
    # ~3 significant digits, leaving vectors noticeably off unit length.
    reps = reps[:, :dimension].detach().to(torch.float32)
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps.cpu().numpy()


def encode_queries(queries: "str | list[str]", dimension: int = 1536):
    """Embed one or more text queries with the DSE model.

    Each query becomes a single-turn user message holding a small dummy image
    plus the text ``"Query: <query>"``. The message is rendered with the chat
    template, terminated with ``<|endoftext|>`` and run through ``model``. The
    last-layer hidden state of the final token is the embedding (see
    ``get_embedding``).

    Args:
        queries: A single query string, or an iterable of query strings.
        dimension: Number of leading embedding dimensions to keep
            (default 1536). Smaller values such as 512 are cheaper to store
            and search (an efficiency trade-off).

    Returns:
        A float32 ``numpy.ndarray`` with one L2-normalised row per query and
        up to ``dimension`` columns. An empty input returns an empty
        ``(0, dimension)`` array without invoking the model.
    """
    # A bare string is a one-element batch; other iterables (tuples,
    # generators) are materialised into a list.
    queries = [queries] if isinstance(queries, str) else list(queries)
    if not queries:
        # Nothing to encode: skip the processor and model entirely.
        return torch.empty(0, dimension, dtype=torch.float32).numpy()

    query_messages = [
        [
            {
                "role": "user",
                "content": [
                    # A dummy image keeps queries on the same processing path
                    # as documents; it is requested at a 1x1 resize to keep
                    # its cost minimal.
                    {
                        "type": "image",
                        "image": Image.new("RGB", (28, 28)),
                        "resized_height": 1,
                        "resized_width": 1,
                    },
                    {"type": "text", "text": f"Query: {query}"},
                ],
            }
        ]
        for query in queries
    ]
    # Appending <|endoftext|> makes it the last token, which get_embedding pools.
    query_texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        + "<|endoftext|>"
        for msg in query_messages
    ]
    query_image_inputs, query_video_inputs = process_vision_info(query_messages)
    query_inputs = processor(
        text=query_texts,
        images=query_image_inputs,
        videos=query_video_inputs,
        padding="longest",
        return_tensors="pt",
    ).to(device)
    query_inputs = model.prepare_inputs_for_generation(**query_inputs, use_cache=False)
    with torch.no_grad():
        output = model(**query_inputs, return_dict=True, output_hidden_states=True)
    return get_embedding(output.hidden_states[-1], dimension)


def encode_images(images: "Image.Image | list[Image.Image]", dimension: int = 1536):
    """Embed one or more document images with the DSE model.

    Each image becomes a single-turn user message holding the image plus the
    prompt ``"What is shown in this image?"``. The message is rendered with
    the chat template, terminated with ``<|endoftext|>`` and run through
    ``model`` exactly once, without autograd. The last-layer hidden state of
    the final token is the embedding (see ``get_embedding``).

    Args:
        images: A single PIL image, or an iterable of PIL images.
        dimension: Number of leading embedding dimensions to keep
            (default 1536). Smaller values such as 512 are cheaper to store
            and search (an efficiency trade-off).

    Returns:
        A float32 ``numpy.ndarray`` with one L2-normalised row per image and
        up to ``dimension`` columns. An empty input returns an empty
        ``(0, dimension)`` array without invoking the model.
    """
    # A bare image is a one-element batch; other iterables are materialised.
    images = [images] if isinstance(images, Image.Image) else list(images)
    if not images:
        # Nothing to encode: skip the processor and model entirely.
        return torch.empty(0, dimension, dtype=torch.float32).numpy()

    doc_messages = [
        [
            {
                "role": "user",
                "content": [
                    # To cap per-image cost, "resized_height"/"resized_width"
                    # (e.g. 680) can be added to this dict.
                    {"type": "image", "image": image},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            }
        ]
        for image in images
    ]
    # Appending <|endoftext|> makes it the last token, which get_embedding pools.
    doc_texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        + "<|endoftext|>"
        for msg in doc_messages
    ]
    doc_image_inputs, doc_video_inputs = process_vision_info(doc_messages)
    doc_inputs = processor(
        text=doc_texts,
        images=doc_image_inputs,
        videos=doc_video_inputs,
        padding="longest",
        return_tensors="pt",
    ).to(device)
    doc_inputs = model.prepare_inputs_for_generation(**doc_inputs, use_cache=False)
    # Exactly one forward pass, under no_grad: an extra pass outside no_grad
    # would double the compute and keep an unused autograd graph in memory.
    with torch.no_grad():
        output = model(**doc_inputs, return_dict=True, output_hidden_states=True)
    return get_embedding(output.hidden_states[-1], dimension)