mgbam committed on
Commit
aecd4d7
·
verified ·
1 Parent(s): 9c2e629

Update hf_models.py

Browse files
Files changed (1) hide show
  1. hf_models.py +64 -106
hf_models.py CHANGED
@@ -4,19 +4,14 @@ import logging
4
  from PIL import Image
5
  import base64
6
  import io
7
- from typing import Optional, Tuple, Dict, Any # Added Dict, Any
8
 
9
- # Assume logger is configured elsewhere, consistent with previous sections
10
  logger = logging.getLogger(__name__)
11
 
12
  # --- Constants ---
13
- # Specify the VQA model identifier from Hugging Face Hub (ensure it supports the Inference API)
14
- # LLaVA is a common choice for general VQA. Verify compatibility and expected payload/response.
15
- HF_VQA_MODEL_ID: str = "llava-hf/llava-1.5-7b-hf"
16
- # HF_VQA_MODEL_ID: str = "Salesforce/blip-vqa-base" # Another example, payload/response differs!
17
-
18
- # Timeout for the API request in seconds. Adjust based on expected model inference time.
19
- HF_API_TIMEOUT: int = 60 # Slightly shorter timeout, adjust if needed
20
 
21
  # --- Helper Functions ---
22
 
@@ -28,24 +23,28 @@ def get_hf_api_token() -> Optional[str]:
28
  The API token string if found, otherwise None.
29
  """
30
  try:
31
- # Access the token defined in Streamlit's secrets management
32
- # (e.g., in secrets.toml or environment variables for deployed apps)
33
  token = st.secrets.get("HF_API_TOKEN")
34
  if token:
35
  logger.debug("Hugging Face API Token retrieved successfully from secrets.")
36
  return token
37
  else:
38
- # Log the absence, but the user-facing warning happens in the main query function
39
  logger.warning("HF_API_TOKEN not found in Streamlit secrets.")
40
  return None
41
  except Exception as e:
42
- # Avoid exposing detailed error related to secrets management to the user here
43
  logger.error(f"Error accessing Streamlit secrets for HF API Token: {e}", exc_info=True)
44
- # The calling function should inform the user about the configuration issue
45
  return None
46
 
47
  def _crop_image_to_roi(image: Image.Image, roi: Dict[str, int]) -> Optional[Image.Image]:
48
- """Crops a PIL Image to the specified ROI dictionary."""
 
 
 
 
 
 
 
 
 
49
  try:
50
  x0, y0 = int(roi['left']), int(roi['top'])
51
  x1, y1 = x0 + int(roi['width']), y0 + int(roi['height'])
@@ -60,10 +59,9 @@ def _crop_image_to_roi(image: Image.Image, roi: Dict[str, int]) -> Optional[Imag
60
  logger.error(f"Failed to crop image to ROI ({roi}): {e}", exc_info=True)
61
  return None
62
 
63
-
64
  def _image_to_base64(image: Image.Image) -> str:
65
  """
66
- Converts a PIL Image object to a base64 encoded string (PNG format).
67
 
68
  Args:
69
  image: The PIL Image object.
@@ -72,59 +70,52 @@ def _image_to_base64(image: Image.Image) -> str:
72
  The base64 encoded string representation of the image.
73
 
74
  Raises:
75
- Exception: If image saving or encoding fails.
76
  """
77
  try:
78
  buffered = io.BytesIO()
79
- image.save(buffered, format="PNG") # PNG is generally preferred for lossless quality
80
  img_byte = buffered.getvalue()
81
  base64_str = base64.b64encode(img_byte).decode("utf-8")
82
  logger.debug(f"Image successfully encoded to base64 string ({len(base64_str)} chars).")
83
  return base64_str
84
  except Exception as e:
85
  logger.error(f"Error during image to base64 conversion: {e}", exc_info=True)
86
- # Re-raise to be caught by the calling function for user feedback
87
  raise Exception(f"Failed to process image for API request: {e}")
88
 
89
-
90
- # Note: Caching API calls is generally complex due to external factors (API status, model updates)
91
- # and potentially dynamic inputs (image content, question). Avoid simple Streamlit caching here.
92
  def query_hf_vqa_inference_api(
93
  image: Image.Image,
94
  question: str,
95
- roi: Optional[Dict[str, int]] = None # Added roi parameter
96
  ) -> Tuple[str, bool]:
97
  """
98
- Queries a specified Hugging Face VQA model via the serverless Inference API.
99
 
100
- Handles API token retrieval, optional ROI cropping, image encoding, request
101
- construction (model-specific payload), API call, and response parsing.
 
102
 
103
  Args:
104
- image: The PIL Image object for analysis.
105
  question: The question to ask about the image.
106
- roi: An optional dictionary defining the region of interest to focus on.
107
- Expected keys: {'left', 'top', 'width', 'height'}.
108
 
109
  Returns:
110
  A tuple containing:
111
- - str: The generated answer string, or an error message prefixed
112
- with "[Fallback Error]" or "[Fallback Unavailable]".
113
- - bool: True if the query was successful and an answer was parsed,
114
- False otherwise.
115
  """
116
  hf_api_token = get_hf_api_token()
117
  if not hf_api_token:
118
- # Return a user-friendly message indicating configuration issue
119
  return "[Fallback Unavailable] Hugging Face API Token not configured.", False
120
 
121
- # Construct the API endpoint URL
122
  api_url = f"https://api-inference.huggingface.co/models/{HF_VQA_MODEL_ID}"
123
  headers = {"Authorization": f"Bearer {hf_api_token}"}
124
 
125
- logger.info(f"Preparing Hugging Face VQA query. Model: {HF_VQA_MODEL_ID}, ROI: {bool(roi)}")
126
 
127
- # --- Prepare Image ---
128
  image_to_send = image
129
  if roi:
130
  cropped_image = _crop_image_to_roi(image, roi)
@@ -132,114 +123,81 @@ def query_hf_vqa_inference_api(
132
  image_to_send = cropped_image
133
  logger.info("Using ROI-cropped image for HF VQA query.")
134
  else:
135
- # Inform user/log that cropping failed, but proceed with full image
136
- logger.warning("Failed to crop image to ROI, proceeding with full image for HF VQA.")
137
- # Optionally, return an error if ROI processing is critical:
138
- # return "[Fallback Error] Failed processing ROI for image.", False
139
 
140
  try:
141
  img_base64 = _image_to_base64(image_to_send)
142
  except Exception as e:
143
- # Error already logged in _image_to_base64
144
- return f"[Fallback Error] {e}", False # Return the error message raised by the helper
145
-
146
- # --- Construct Payload (CRITICAL: Model-Dependent) ---
147
- # The structure of the 'payload' MUST match the specific model's requirements
148
- # as documented on its Hugging Face model card. Examples below.
149
 
150
- # Example Payload for LLaVA models (e.g., llava-hf/llava-1.5-7b-hf):
 
151
  payload = {
152
- "inputs": f"USER: <image>\n{question}\nASSISTANT:", # Prompt includes placeholder and question
153
- "parameters": {"max_new_tokens": 250} # Optional: control output length
154
  }
155
-
156
- # Example Payload for BLIP models (e.g., Salesforce/blip-vqa-base):
157
- # payload = {
158
- # "inputs": {
159
- # "image": img_base64,
160
- # "question": question
161
- # }
162
- # }
163
-
164
- # Example Payload for some other models might just need image bytes directly:
165
- # headers = {"Authorization": f"Bearer {hf_api_token}", "Content-Type": "image/png"} # Different headers!
166
- # payload = image_to_send.tobytes() # Send raw bytes
167
-
168
- logger.debug(f"Sending request to HF VQA API: {api_url}. Payload keys: {list(payload.keys()) if isinstance(payload, dict) else 'Raw Bytes'}")
169
 
170
  # --- Make API Call ---
171
  try:
172
- response = requests.post(api_url, headers=headers, json=payload, timeout=HF_API_TIMEOUT) # Use json=payload for dicts
173
- # For raw bytes payload: requests.post(api_url, headers=headers, data=payload, timeout=HF_API_TIMEOUT)
174
-
175
- response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
176
-
177
  response_data = response.json()
178
- logger.debug(f"HF VQA Raw Response JSON: {response_data}")
179
 
180
- # --- Response Parsing (CRITICAL: Model-Dependent) ---
181
- # Adapt this section based on the JSON structure returned by HF_VQA_MODEL_ID
182
  parsed_answer: Optional[str] = None
183
 
184
- # Example Parsing for LLaVA style response (often a list with generated_text)
185
- if isinstance(response_data, list) and len(response_data) > 0 and "generated_text" in response_data[0]:
186
  full_text = response_data[0]["generated_text"]
187
- # Extract only the generated part after the "ASSISTANT:" marker
188
  assistant_marker = "ASSISTANT:"
189
  if assistant_marker in full_text:
190
- parsed_answer = full_text.split(assistant_marker, 1)[-1].strip()
191
  else:
192
- parsed_answer = full_text.strip() # Fallback if marker isn't found
193
-
194
- # Example Parsing for BLIP style response (dict with "answer")
195
  elif isinstance(response_data, dict) and "answer" in response_data:
196
- parsed_answer = response_data["answer"]
197
-
198
- # Add more 'elif' blocks here for other expected response structures
199
 
200
- # --- Validate and Return Parsed Answer ---
201
- if parsed_answer is not None and parsed_answer.strip():
202
  logger.info(f"Successfully parsed answer from HF VQA ({HF_VQA_MODEL_ID}).")
203
  return parsed_answer.strip(), True
204
  else:
205
- logger.warning(f"HF VQA response received, but failed to parse a valid answer. Response: {response_data}")
206
  return "[Fallback Error] Could not parse a valid answer from the model's response.", False
207
 
208
  except requests.exceptions.Timeout:
209
- error_msg = f"Request to Hugging Face VQA API timed out after {HF_API_TIMEOUT} seconds ({api_url}). The model might be taking too long."
210
  logger.error(error_msg)
211
- return f"[Fallback Error] Request timed out.", False # Keep user message concise
212
  except requests.exceptions.HTTPError as e:
213
  status_code = e.response.status_code
214
  error_detail = ""
215
  try:
216
- # Try to get specific error message from JSON response
217
  error_detail = e.response.json().get('error', e.response.text)
218
- except: # Fallback if response is not JSON or parsing fails
219
  error_detail = e.response.text
220
 
221
- log_message = f"HF API HTTP Error ({status_code}) for {api_url}. Details: {error_detail}"
222
  user_message = f"[Fallback Error] API request failed (Status: {status_code})."
223
 
224
  if status_code == 401:
225
- user_message += " Check Hugging Face API Token configuration."
226
- logger.error(log_message, exc_info=False) # Don't need traceback for auth error
227
  elif status_code == 404:
228
- user_message += f" Check if Model ID '{HF_VQA_MODEL_ID}' is correct and supports Inference API."
229
  logger.error(log_message, exc_info=False)
230
- elif status_code == 503: # Model loading or unavailable
231
- user_message += " The model may be loading, please wait and try again."
232
- logger.warning(log_message, exc_info=False) # Warning, as it might be temporary
233
- else: # Other HTTP errors
234
  user_message += " Please check logs for details."
235
- logger.error(log_message, exc_info=True) # Include traceback for unexpected HTTP errors
236
-
237
  return user_message, False
238
  except requests.exceptions.RequestException as e:
239
- # Catch other network-related errors (DNS, connection refused, etc.)
240
- logger.error(f"Network error during HF API request to {api_url}: {e}", exc_info=True)
241
- return f"[Fallback Error] Network error occurred while contacting the API.", False
242
  except Exception as e:
243
- # Catch-all for any other unexpected errors (e.g., JSON decoding, parsing logic)
244
- logger.error(f"Unexpected error during HF VQA query or response processing: {e}", exc_info=True)
245
- return f"[Fallback Error] An unexpected error occurred during processing.", False
 
4
  from PIL import Image
5
  import base64
6
  import io
7
+ from typing import Optional, Tuple, Dict, Any
8
 
9
+ # Configure logger (assumed to be set up globally in your app)
10
  logger = logging.getLogger(__name__)
11
 
12
  # --- Constants ---
13
+ HF_VQA_MODEL_ID: str = "llava-hf/llava-1.5-7b-hf" # Example model supporting VQA via the Hugging Face Inference API
14
+ HF_API_TIMEOUT: int = 60 # API request timeout in seconds
 
 
 
 
 
15
 
16
  # --- Helper Functions ---
17
 
 
23
  The API token string if found, otherwise None.
24
  """
25
  try:
 
 
26
  token = st.secrets.get("HF_API_TOKEN")
27
  if token:
28
  logger.debug("Hugging Face API Token retrieved successfully from secrets.")
29
  return token
30
  else:
 
31
  logger.warning("HF_API_TOKEN not found in Streamlit secrets.")
32
  return None
33
  except Exception as e:
 
34
  logger.error(f"Error accessing Streamlit secrets for HF API Token: {e}", exc_info=True)
 
35
  return None
36
 
37
  def _crop_image_to_roi(image: Image.Image, roi: Dict[str, int]) -> Optional[Image.Image]:
38
+ """
39
+ Crops a PIL Image to the specified ROI.
40
+
41
+ Args:
42
+ image: The PIL Image object.
43
+ roi: A dictionary with keys 'left', 'top', 'width', and 'height'.
44
+
45
+ Returns:
46
+ A cropped Image if successful, or None if cropping fails.
47
+ """
48
  try:
49
  x0, y0 = int(roi['left']), int(roi['top'])
50
  x1, y1 = x0 + int(roi['width']), y0 + int(roi['height'])
 
59
  logger.error(f"Failed to crop image to ROI ({roi}): {e}", exc_info=True)
60
  return None
61
 
 
62
def _image_to_base64(image: Image.Image) -> str:
    """
    Encodes a PIL Image as a base64 string (PNG format).

    Args:
        image: The PIL Image object to encode.

    Returns:
        The base64 encoded string representation of the image.

    Raises:
        Exception: If the image cannot be saved or encoded.
    """
    try:
        # PNG keeps the encoding lossless; BytesIO avoids touching disk.
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            raw_bytes = buffer.getvalue()
        encoded = base64.b64encode(raw_bytes).decode("utf-8")
        logger.debug(f"Image successfully encoded to base64 string ({len(encoded)} chars).")
        return encoded
    except Exception as e:
        # Log with traceback, then surface a caller-friendly error.
        logger.error(f"Error during image to base64 conversion: {e}", exc_info=True)
        raise Exception(f"Failed to process image for API request: {e}")
85
 
 
 
 
86
def query_hf_vqa_inference_api(
    image: Image.Image,
    question: str,
    roi: Optional[Dict[str, int]] = None
) -> Tuple[str, bool]:
    """
    Queries the Hugging Face VQA model via the serverless Inference API.

    Handles API token retrieval, optional ROI cropping, image encoding,
    payload construction (model-specific), the API call, and response parsing.

    Args:
        image: The PIL Image object to analyze.
        question: The question to ask about the image.
        roi: Optional region of interest to crop to before querying.
            Expected keys: 'left', 'top', 'width', 'height'.

    Returns:
        A tuple containing:
            - str: The generated answer, or an error message prefixed with
              "[Fallback Error]" / "[Fallback Unavailable]".
            - bool: True if a valid answer was parsed, False otherwise.
    """
    hf_api_token = get_hf_api_token()
    if not hf_api_token:
        return "[Fallback Unavailable] Hugging Face API Token not configured.", False

    api_url = f"https://api-inference.huggingface.co/models/{HF_VQA_MODEL_ID}"
    headers = {"Authorization": f"Bearer {hf_api_token}"}

    logger.info(f"Preparing HF VQA query. Model: {HF_VQA_MODEL_ID}, Using ROI: {bool(roi)}")

    # --- Prepare Image: apply ROI crop if provided ---
    image_to_send = image
    if roi:
        cropped_image = _crop_image_to_roi(image, roi)
        if cropped_image:
            image_to_send = cropped_image
            logger.info("Using ROI-cropped image for HF VQA query.")
        else:
            # Best-effort: fall back to the full image rather than failing.
            logger.warning("ROI cropping failed; proceeding with full image.")

    try:
        img_base64 = _image_to_base64(image_to_send)
    except Exception as e:
        # Helper already logged the details; surface its message to the caller.
        return f"[Fallback Error] {e}", False

    # --- Construct Payload ---
    # BUG FIX: previously img_base64 was computed but never included in the
    # payload, so the model received the prompt with no image to answer about.
    # The exact input field names are model-dependent — NOTE(review): verify
    # this {"image", "text"} structure against the model card for
    # HF_VQA_MODEL_ID before relying on it.
    payload = {
        "inputs": {
            "image": img_base64,
            "text": f"USER: <image>\n{question}\nASSISTANT:",
        },
        "parameters": {"max_new_tokens": 250},
    }
    logger.debug(f"Payload prepared with keys: {list(payload.keys())}")

    # --- Make API Call ---
    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=HF_API_TIMEOUT)
        response.raise_for_status()
        response_data = response.json()
        logger.debug(f"HF VQA API response: {response_data}")

        # --- Parse Response (structure is model-dependent) ---
        parsed_answer: Optional[str] = None

        # LLaVA-style response: a list of {"generated_text": ...}.
        if isinstance(response_data, list) and response_data and "generated_text" in response_data[0]:
            full_text = response_data[0]["generated_text"]
            assistant_marker = "ASSISTANT:"
            if assistant_marker in full_text:
                # Keep only the text generated after the assistant marker.
                parsed_answer = full_text.split(assistant_marker, 1)[-1].strip()
            else:
                parsed_answer = full_text.strip()
        # BLIP-style response: a dict with an "answer" key.
        elif isinstance(response_data, dict) and "answer" in response_data:
            parsed_answer = response_data["answer"]

        if parsed_answer and parsed_answer.strip():
            logger.info(f"Successfully parsed answer from HF VQA ({HF_VQA_MODEL_ID}).")
            return parsed_answer.strip(), True
        else:
            logger.warning(f"Response received but no valid answer parsed. Response: {response_data}")
            return "[Fallback Error] Could not parse a valid answer from the model's response.", False

    except requests.exceptions.Timeout:
        logger.error(f"Request to HF VQA API timed out after {HF_API_TIMEOUT} seconds.")
        return "[Fallback Error] Request timed out.", False
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        try:
            # Prefer the structured error message if the body is JSON.
            error_detail = e.response.json().get('error', e.response.text)
        except Exception:
            error_detail = e.response.text

        log_message = f"HTTP Error ({status_code}) for {api_url}. Details: {error_detail}"
        user_message = f"[Fallback Error] API request failed (Status: {status_code})."

        if status_code == 401:
            user_message += " Check HF API Token configuration."
            logger.error(log_message, exc_info=False)  # auth failure: traceback adds nothing
        elif status_code == 404:
            user_message += f" Verify that Model ID '{HF_VQA_MODEL_ID}' is correct."
            logger.error(log_message, exc_info=False)
        elif status_code == 503:
            # 503 typically means the model is cold-loading; treat as transient.
            user_message += " The model may be loading; please try again later."
            logger.warning(log_message, exc_info=False)
        else:
            user_message += " Please check logs for details."
            logger.error(log_message, exc_info=True)

        return user_message, False
    except requests.exceptions.RequestException as e:
        # DNS failures, connection refused, etc.
        logger.error(f"Network error during HF API request: {e}", exc_info=True)
        return "[Fallback Error] Network error occurred while contacting the API.", False
    except Exception as e:
        # Catch-all for JSON decoding or parsing-logic surprises.
        logger.error(f"Unexpected error during HF VQA query: {e}", exc_info=True)
        return "[Fallback Error] An unexpected error occurred during processing.", False