JUNJIE99 committed (verified)
Commit: db7d194
Parent(s): 8067c48

Upload folder using huggingface_hub

Files changed (2):
  1. demo_test.py +24 -3
  2. modeling_llavanext_for_embedding.py +71 -0
demo_test.py CHANGED
@@ -11,13 +11,34 @@
 # print(outputs)
 
 
+import torch
 from transformers import LlavaNextProcessor, AutoModel
 
 model = AutoModel.from_pretrained("/share/junjie/code/VISTA2/240920mllmemb/llm_dense_retriever/MMRet-release/MMRet-MLLM", trust_remote_code=True).cuda()
+model = model.eval()
 processor = LlavaNextProcessor.from_pretrained("/share/junjie/code/VISTA2/240920mllmemb/llm_dense_retriever/MMRet-release/MMRet-MLLM")
 
-texts = "find a image of a dog"
+texts = "[INST] \n <instruct> <query> find a image of a dog \n [/INST]"
 
 inputs = processor(texts, return_tensors="pt").to("cuda")
-outputs = model(**inputs)
-print(outputs)
+outputs = model(**inputs)[:, -1, :]
+embeddings = torch.nn.functional.normalize(outputs, dim=-1)
+
+print(embeddings)
+
+
+
+from transformers import LlavaNextProcessor, AutoModel
+import torch
+
+model_name = "/share/junjie/code/VISTA2/240920mllmemb/llm_dense_retriever/MMRet-release/MMRet-MLLM"
+model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
+model = model.eval()
+model.set_processor(model_name)
+inputs = model.data_process(text="find a image of a dog", q_or_c="query")
+
+model_output = model(**inputs, output_hidden_states=True)
+embeddings = model_output[:, -1, :]
+embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
+
+print(embeddings)
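
For reference, a hedged usage sketch (not part of this commit): a composed image+text query built only on the set_processor/data_process interface added above, with an explicit task instruction as recommended by the warning in prepare_text_input(). The image path "./dog.jpg" and the modification text are illustrative assumptions.

# Usage sketch (assumption, not from the commit): composed image+text query.
import torch
from transformers import AutoModel

model_name = "/share/junjie/code/VISTA2/240920mllmemb/llm_dense_retriever/MMRet-release/MMRet-MLLM"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().eval()
model.set_processor(model_name)

# Task instruction for composed image retrieval, taken from the example string
# in prepare_text_input(); any MMEB-style instruction could be used instead.
task_instruction = (
    "Retrieve the target image that best meets the combined criteria by using "
    "both the provided image and the image retrieval instructions: "
)

# "./dog.jpg" and the modification text below are hypothetical placeholders.
query_inputs = model.data_process(
    images="./dog.jpg",
    text="make the dog run on the beach",
    q_or_c="query",
    task_instruction=task_instruction,
)
query_emb = model(**query_inputs, output_hidden_states=True)[:, -1, :]
query_emb = torch.nn.functional.normalize(query_emb, dim=-1)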
modeling_llavanext_for_embedding.py CHANGED
@@ -257,3 +257,74 @@ class LLaVANextForEmbedding(LlavaNextForConditionalGeneration):
 
         return outputs
 
+    def set_processor(self, model_name):
+        self.processor = LlavaNextProcessor.from_pretrained(model_name)
+    def prepare_text_input(self, image=None, text=None, q_or_c=None, task_instruction=None):
+        task_instruction_example_cir = "Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: "
+
+        assert q_or_c in ["query", "candidate", "q", "c"]
+
+        if "q" in q_or_c:
+            if task_instruction is None:
+                text_input = "[INST] \n <instruct> <query>"
+                print(f"""
+                Warning: For optimal performance, MMRet-MLLM requires the task instruction to be specified in the query.
+                For example, for the composed image retrieval task, you might use a specific instruction like: {task_instruction_example_cir}.
+                Instructions for other tasks can be referenced in the MMEB benchmark.
+                """)
+            elif task_instruction is not None:
+                text_input = f"[INST] \n <instruct> {task_instruction} <query> "
+
+            if text is not None:
+                text_input = f"{text_input} {text} \n"
+            if image is not None:
+                text_input = f"{text_input} <image>"
+
+            text_input = f"{text_input} [/INST]"
+        else:
+            text_input = "[INST] "
+            if text is not None:
+                text_input = f"{text_input} {text} \n"
+            if image is not None:
+                text_input = f"{text_input} <image>"
+            text_input = f"{text_input} [/INST]"
+
+        return text_input
+
+    def data_process(self, images=None, text=None, q_or_c=None, task_instruction=None):
+        if images is not None:
+            _is_list = isinstance(images, list)
+        elif text is not None:
+            _is_list = isinstance(text, list)
+        else:
+            raise ValueError("images and text cannot be both None.")
+
+        assert q_or_c in ["query", "candidate", "q", "c"]
+
+        if not _is_list :
+            text_input = self.prepare_text_input(images, text, q_or_c, task_instruction)
+            text_input = [text_input]
+
+            print(text_input)
+
+            if images is not None:
+                images = Image.open(images).resize((512,512)).convert("RGB")
+                images = [images]
+                inputs = self.processor(images=images, text=text_input, return_tensors="pt", padding=True)
+            else:
+                inputs = self.processor(text=text_input, return_tensors="pt", padding=True)
+
+        else:
+            text_input = [self.prepare_text_input(_image, _text, q_or_c, task_instruction) for _image, _text in zip(images, text)]
+
+            print(text_input)
+
+            if images is not None:
+                images = [Image.open(_image).resize((512,512)).convert("RGB") for _image in images]
+                inputs = self.processor(images=images, text=text_input, return_tensors="pt", padding=True)
+            else:
+                inputs = self.processor(text=text_input, return_tensors="pt", padding=True)
+
+        inputs = inputs.to(self.device)
+
+        return inputs
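
A matching candidate-side sketch, again an assumption rather than part of the commit: candidates are encoded one at a time through the non-list path of data_process() with q_or_c="candidate", then scored against the normalized query embedding (query_emb from the sketch above) by a dot product, which equals cosine similarity for L2-normalized vectors. The candidate strings are illustrative.

# Usage sketch (assumption, not from the commit): encode text candidates and
# rank them against a previously computed, L2-normalized query embedding.
import torch

candidate_texts = ["A photo of a dog.", "A photo of a cat."]  # illustrative
candidate_embs = []
for cand in candidate_texts:
    cand_inputs = model.data_process(text=cand, q_or_c="candidate")
    emb = model(**cand_inputs, output_hidden_states=True)[:, -1, :]
    candidate_embs.append(torch.nn.functional.normalize(emb, dim=-1))
candidate_embs = torch.cat(candidate_embs, dim=0)

# Dot product of normalized embeddings = cosine similarity; higher is better.
scores = query_emb @ candidate_embs.T
print(scores)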