Visual Document Retrieval · Transformers · Safetensors · ColPali · English · pretraining
adrish committed · Commit 9931aed · Parent: 0d89ced

updated the code

Files changed (1)
  1. handler.py +44 -27
handler.py CHANGED
@@ -9,74 +9,91 @@ from typing import Dict, Any, List
 class EndpointHandler:
     def __init__(self, model_path: str = None):
         """
-        Initialize the endpoint handler using the ColPali retrieval model.
+        Initialize the endpoint handler using the ColPali model for OCR extraction.
         If no model path is provided, it defaults to 'vidore/colpali-v1.3-hf'.
         """
         if model_path is None:
             model_path = os.path.dirname(os.path.realpath(__file__))
         try:
+            # Use GPU if available, otherwise CPU.
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            # Use the specialized ColPaliForRetrieval class.
+            # Load the specialized ColPali model (designed for retrieval but repurposed here for OCR generation).
             self.model = ColPaliForRetrieval.from_pretrained(
                 model_path,
                 device_map="cuda" if torch.cuda.is_available() else "cpu",
                 trust_remote_code=True,
                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
             ).to(self.device)
-            # Use the specialized ColPaliProcessor.
+            # Load the processor that handles image preprocessing.
             self.processor = ColPaliProcessor.from_pretrained(model_path, trust_remote_code=True)
         except Exception as e:
             raise RuntimeError(f"Error loading model or processor: {e}")
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Process the input data, run inference using the ColPali retrieval model,
-        and return the outputs.
+        Process the input data for OCR extraction.
 
         Expects a dictionary with an "inputs" key containing a list of dictionaries.
-        Each dictionary should have:
-          - "image": a base64-encoded image string.
-          - "prompt": (optional) a text prompt (default is used if missing).
+        Each dictionary must have an "image" key with a base64-encoded image string.
+        For OCR extraction, no text prompt is provided.
         """
         try:
             inputs_list = data.get("inputs", [])
-            config = data.get("config", {})
-
             if not inputs_list or not isinstance(inputs_list, list):
-                return {"error": "Inputs should be a list of dictionaries with 'image' and optionally 'prompt' keys."}
+                return {"error": "Inputs should be a list of dictionaries with an 'image' key."}
 
             images: List[Image.Image] = []
-            texts: List[str] = []
-
             for item in inputs_list:
                 image_b64 = item.get("image")
                 if not image_b64:
                     return {"error": "One of the input items is missing 'image' data."}
                 try:
-                    # Decode the base64-encoded image and convert to RGB.
+                    # Decode the base64 string and convert to an RGB PIL image.
                     image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
                     images.append(image)
                 except Exception as e:
                     return {"error": f"Failed to decode one of the images: {e}"}
-                # Use the provided prompt or a default prompt.
-                prompt = item.get("prompt", "Describe the image content in detail.")
-                texts.append(prompt)
 
-            # Prepare inputs with the ColPali processor.
+            # Process only images with the processor (to avoid the text+image conflict).
             model_inputs = self.processor(
                 images=images,
-                text=texts,
-                padding=True,
                 return_tensors="pt",
+                padding=True,
             ).to(self.device)
 
-            # For retrieval, we call the model directly rather than using generate().
-            outputs = self.model(**model_inputs)
-            # Assuming that the model returns logits or retrieval scores,
-            # we extract and convert them to lists.
-            retrieval_scores = outputs.logits.tolist() if hasattr(outputs, "logits") else outputs
+            # Manually create a dummy text prompt by inserting a beginning-of-sequence token.
+            # This is necessary to trigger text generation even though no prompt is provided.
+            bos_token_id = (
+                self.processor.tokenizer.bos_token_id
+                or self.processor.tokenizer.cls_token_id
+                or self.processor.tokenizer.pad_token_id
+            )
+            if bos_token_id is None:
+                raise RuntimeError("No BOS token found in the tokenizer.")
+            batch_size = model_inputs["pixel_values"].shape[0]
+            dummy_input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long).to(self.device)
+            model_inputs["input_ids"] = dummy_input_ids
+
+            # Generation parameters (can be overridden via the "config" field).
+            config = data.get("config", {})
+            max_new_tokens = config.get("max_new_tokens", 256)
+            temperature = config.get("temperature", 0.8)
+            num_return_sequences = config.get("num_return_sequences", 1)
+            do_sample = bool(config.get("do_sample", True))
+
+            # Call generate on the model using the image-only inputs augmented with the dummy text.
+            outputs = self.model.generate(
+                **model_inputs,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                num_return_sequences=num_return_sequences,
+                do_sample=do_sample,
+            )
+
+            # Decode generated tokens into text using the processor's tokenizer.
+            text_output = self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
-            return {"responses": retrieval_scores}
+            return {"responses": text_output}
 
         except Exception as e:
             return {"error": f"Unexpected error: {e}"}
@@ -87,7 +104,7 @@ _service = EndpointHandler()
 def handle(data, context):
     """
     Entry point for the Hugging Face dedicated inference endpoint.
-    Processes the input data and returns the model's outputs.
+    Processes input data and returns the extracted OCR text.
     """
     try:
         if data is None:
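Below is a minimal sketch of how a client could exercise the updated handler locally, for example before deploying this commit to an endpoint. The image path, the payload values, and the greedy-decoding override are illustrative assumptions, not part of the commit; only the "inputs"/"config" request schema comes from handler.py above.

example_client.py (illustrative):

import base64

from handler import EndpointHandler

# Encode a document page as base64, matching the "image" field the handler expects.
# "page.png" is a placeholder path for this sketch.
with open("page.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# With no model_path argument, the handler loads the model files from its own directory.
handler = EndpointHandler()

payload = {
    "inputs": [{"image": image_b64}],  # one dict per image; no prompt needed after this commit
    "config": {  # optional overrides read by __call__
        "max_new_tokens": 128,
        "do_sample": False,  # greedy decoding instead of the handler's sampled default
    },
}

result = handler(payload)
print(result.get("responses") or result.get("error"))

Overriding do_sample to False is usually a safer choice for OCR-style extraction than the handler's sampled default, since sampled generations can drift from the page text.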