manu committed
Commit 9b1e831 • 1 Parent(s): f7bffc9

Update everything a bit!

Files changed (1):
  1. app.py +18 -31
app.py CHANGED
@@ -3,7 +3,6 @@ import spaces
 
 import gradio as gr
 import torch
-from colpali_engine.models.paligemma_colbert_architecture import ColPali
 from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
 from colpali_engine.utils.colpali_processing_utils import (
     process_images,
@@ -13,19 +12,18 @@ from pdf2image import convert_from_path
 from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoProcessor
 
-# Load model
-model_name = "vidore/colpali-v1.2"
-token = os.environ.get("HF_TOKEN")
-model = ColPali.from_pretrained(
-    "vidore/colpaligemma-3b-pt-448-base", torch_dtype=torch.bfloat16, device_map="cuda", token = token).eval()
+from colpali_engine.models import ColQwen2, ColQwen2Processor
 
-model.load_adapter(model_name)
-model = model.eval()
-processor = AutoProcessor.from_pretrained(model_name, token = token)
-
-mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
+
+model = ColQwen2.from_pretrained(
+    "manu/colqwen2-v1.0-alpha",
+    torch_dtype=torch.bfloat16,
+    device_map="cuda:0",  # or "mps" if on Apple Silicon
+).eval()
+processor = ColQwen2Processor.from_pretrained("manu/colqwen2-v1.0-alpha")
+
 
 
 @spaces.GPU
@@ -37,15 +35,13 @@ def search(query: str, ds, images, k):
 
     qs = []
     with torch.no_grad():
-        batch_query = process_queries(processor, [query], mock_image)
-        batch_query = {k: v.to(device) for k, v in batch_query.items()}
+        batch_query = processor.process_queries([query]).to(model.device)
         embeddings_query = model(**batch_query)
         qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
 
-    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
-    scores = retriever_evaluator.evaluate(qs, ds)
+    scores = processor.score(qs, ds, device=device)
 
-    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
+    top_k_indices = scores[0].topk(k).indices.tolist()
 
     results = []
     for idx in top_k_indices:
@@ -75,21 +71,19 @@ def convert_files(files):
 @spaces.GPU
 def index_gpu(images, ds):
     """Example script to run inference with ColPali"""
-
+
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    if device != model.device:
+        model.to(device)
+
     # run inference - docs
     dataloader = DataLoader(
         images,
         batch_size=4,
         shuffle=False,
-        collate_fn=lambda x: process_images(processor, x),
+        collate_fn=lambda x: processor.process_images(x).to(model.device),
     )
 
-
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    if device != model.device:
-        model.to(device)
-
-
     for batch_doc in tqdm(dataloader):
         with torch.no_grad():
             batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
@@ -98,8 +92,6 @@ def index_gpu(images, ds):
     return f"Uploaded and converted {len(images)} pages", ds, images
 
 
-def get_example():
-    return [[["climate_youth_magazine.pdf"], "How much tropical forest is cut annually ?"]]
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚")
@@ -128,11 +120,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     query = gr.Textbox(placeholder="Enter your query here", label="Query")
    k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
 
-    # with gr.Row():
-    #     gr.Examples(
-    #         examples=get_example(),
-    #         inputs=[file, query],
-    #     )
 
     # Define the actions
     search_button = gr.Button("🔍 Search", variant="primary")
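
For reference, the migrated code paths compose into the minimal end-to-end sketch below. The model name and every API call (process_images, process_queries, score, the torch.unbind pattern) are lifted straight from the diff; the single placeholder page, the embeddings_doc name, and the explicit device fallback are illustrative additions, not part of the Space's exact code.

import torch
from PIL import Image

from colpali_engine.models import ColQwen2, ColQwen2Processor

model_name = "manu/colqwen2-v1.0-alpha"
device = "cuda:0" if torch.cuda.is_available() else "cpu"  # sketch: CPU fallback for local runs

model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)

# Index: embed each page image into a bag of multi-vector embeddings,
# kept on CPU, exactly as index_gpu() accumulates them into `ds`.
images = [Image.new("RGB", (448, 448), (255, 255, 255))]  # placeholder page for illustration
ds = []
with torch.no_grad():
    batch_images = processor.process_images(images).to(model.device)
    embeddings_doc = model(**batch_images)
    ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

# Search: embed the query the same way, then score it against all pages
# with the processor's late-interaction scorer (replacing CustomEvaluator).
qs = []
with torch.no_grad():
    batch_query = processor.process_queries(["How much tropical forest is cut annually?"]).to(model.device)
    embeddings_query = model(**batch_query)
    qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

scores = processor.score(qs, ds, device=device)  # shape: (n_queries, n_pages)
top_k_indices = scores[0].topk(min(5, len(ds))).indices.tolist()
print(top_k_indices)

Note that in the updated search(), `device` is not defined inside the shown hunk; it presumably comes from surrounding code outside the diff context, which is why the sketch defines it explicitly.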