import base64
import io
import logging
from typing import Any, Dict, Optional, Tuple

import requests
import streamlit as st
from PIL import Image

# Configure logger (assumed to be set up globally in your app)
logger = logging.getLogger(__name__)

# --- Constants ---
# Example model supporting VQA via the Hugging Face Inference API
HF_VQA_MODEL_ID: str = "llava-hf/llava-1.5-7b-hf"
HF_API_TIMEOUT: int = 60  # API request timeout in seconds


# --- Helper Functions ---

def get_hf_api_token() -> Optional[str]:
    """
    Retrieves the Hugging Face API Token securely from Streamlit secrets.

    Returns:
        The API token string if found, otherwise None.
    """
    try:
        token = st.secrets.get("HF_API_TOKEN")
        if token:
            logger.debug("Hugging Face API Token retrieved successfully from secrets.")
            return token
        logger.warning("HF_API_TOKEN not found in Streamlit secrets.")
        return None
    except Exception as e:
        # st.secrets can raise if no secrets file exists at all; treat as "not configured".
        logger.error(f"Error accessing Streamlit secrets for HF API Token: {e}", exc_info=True)
        return None


def _crop_image_to_roi(image: Image.Image, roi: Dict[str, int]) -> Optional[Image.Image]:
    """
    Crops a PIL Image to the specified ROI.

    Args:
        image: The PIL Image object.
        roi: A dictionary with keys 'left', 'top', 'width', and 'height'.

    Returns:
        A cropped Image if successful, or None if cropping fails.
    """
    try:
        x0, y0 = int(roi['left']), int(roi['top'])
        x1, y1 = x0 + int(roi['width']), y0 + int(roi['height'])
        box = (x0, y0, x1, y1)
        cropped_img = image.crop(box)
        logger.debug(f"Cropped image to ROI box: {box}")
        return cropped_img
    except KeyError as e:
        logger.error(f"ROI dictionary is missing required key: {e}")
        return None
    except Exception as e:
        logger.error(f"Failed to crop image to ROI ({roi}): {e}", exc_info=True)
        return None


def _image_to_base64(image: Image.Image) -> str:
    """
    Converts a PIL Image object to a base64 encoded PNG string.

    Args:
        image: The PIL Image object.

    Returns:
        The base64 encoded string representation of the image.

    Raises:
        Exception: If the image encoding fails.
    """
    try:
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_byte = buffered.getvalue()
        base64_str = base64.b64encode(img_byte).decode("utf-8")
        logger.debug(f"Image successfully encoded to base64 string ({len(base64_str)} chars).")
        return base64_str
    except Exception as e:
        logger.error(f"Error during image to base64 conversion: {e}", exc_info=True)
        # Chain the original cause so the traceback shows what actually failed.
        raise Exception(f"Failed to process image for API request: {e}") from e


def query_hf_vqa_inference_api(
    image: Image.Image,
    question: str,
    roi: Optional[Dict[str, int]] = None
) -> Tuple[str, bool]:
    """
    Queries the Hugging Face VQA model via the Inference API.

    This function handles API token retrieval, optional ROI cropping, image
    encoding, payload construction (model-specific), API call, and response
    parsing.

    Args:
        image: The PIL Image object to analyze.
        question: The question to ask about the image.
        roi: An optional dictionary specifying the region of interest.
            Expected keys: 'left', 'top', 'width', 'height'.

    Returns:
        A tuple containing:
        - A string with the generated answer or an error message.
        - A boolean indicating success (True) or failure (False).
    """
    hf_api_token = get_hf_api_token()
    if not hf_api_token:
        return "[Fallback Unavailable] Hugging Face API Token not configured.", False

    api_url = f"https://api-inference.huggingface.co/models/{HF_VQA_MODEL_ID}"
    headers = {"Authorization": f"Bearer {hf_api_token}"}
    logger.info(f"Preparing HF VQA query. Model: {HF_VQA_MODEL_ID}, Using ROI: {bool(roi)}")

    # --- Prepare Image: Apply ROI if provided ---
    image_to_send = image
    if roi:
        cropped_image = _crop_image_to_roi(image, roi)
        if cropped_image:
            image_to_send = cropped_image
            logger.info("Using ROI-cropped image for HF VQA query.")
        else:
            # Best-effort: a bad ROI degrades to full-image analysis rather than failing.
            logger.warning("ROI cropping failed; proceeding with full image.")

    try:
        img_base64 = _image_to_base64(image_to_send)
    except Exception as e:
        return f"[Fallback Error] {e}", False

    # --- Construct Payload ---
    # BUG FIX: the original payload contained only the text prompt; the base64
    # image was encoded above but never included in the request, so the model
    # never saw the image. The Inference API VQA task expects
    # {"inputs": {"image": <base64>, "question": <text>}}.
    # Adjust the payload structure as required by the specific model.
    payload: Dict[str, Any] = {
        "inputs": {
            "image": img_base64,
            "question": question,
        },
        "parameters": {"max_new_tokens": 250},
    }
    logger.debug(f"Payload prepared with keys: {list(payload.keys())}")

    # --- Make API Call ---
    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=HF_API_TIMEOUT)
        response.raise_for_status()
        response_data = response.json()
        logger.debug(f"HF VQA API response: {response_data}")

        # --- Parse Response ---
        parsed_answer: Optional[str] = None

        # Example parsing for LLaVA-style responses:
        if isinstance(response_data, list) and response_data and "generated_text" in response_data[0]:
            full_text = response_data[0]["generated_text"]
            assistant_marker = "ASSISTANT:"
            if assistant_marker in full_text:
                # Keep only the model's reply, dropping the echoed prompt.
                parsed_answer = full_text.split(assistant_marker, 1)[-1].strip()
            else:
                parsed_answer = full_text.strip()
        # Example parsing for BLIP-style responses:
        elif isinstance(response_data, dict) and "answer" in response_data:
            parsed_answer = response_data["answer"]

        if parsed_answer and parsed_answer.strip():
            logger.info(f"Successfully parsed answer from HF VQA ({HF_VQA_MODEL_ID}).")
            return parsed_answer.strip(), True
        else:
            logger.warning(f"Response received but no valid answer parsed. Response: {response_data}")
            return "[Fallback Error] Could not parse a valid answer from the model's response.", False

    except requests.exceptions.Timeout:
        error_msg = f"Request to HF VQA API timed out after {HF_API_TIMEOUT} seconds."
        logger.error(error_msg)
        return "[Fallback Error] Request timed out.", False
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        error_detail = ""
        try:
            error_detail = e.response.json().get('error', e.response.text)
        except Exception:
            # Response body was not JSON; fall back to raw text.
            error_detail = e.response.text

        log_message = f"HTTP Error ({status_code}) for {api_url}. Details: {error_detail}"
        user_message = f"[Fallback Error] API request failed (Status: {status_code})."

        if status_code == 401:
            user_message += " Check HF API Token configuration."
            logger.error(log_message, exc_info=False)
        elif status_code == 404:
            user_message += f" Verify that Model ID '{HF_VQA_MODEL_ID}' is correct."
            logger.error(log_message, exc_info=False)
        elif status_code == 503:
            # 503 from the Inference API usually means the model is cold-loading.
            user_message += " The model may be loading; please try again later."
            logger.warning(log_message, exc_info=False)
        else:
            user_message += " Please check logs for details."
            logger.error(log_message, exc_info=True)
        return user_message, False
    except requests.exceptions.RequestException as e:
        logger.error(f"Network error during HF API request: {e}", exc_info=True)
        return "[Fallback Error] Network error occurred while contacting the API.", False
    except Exception as e:
        logger.error(f"Unexpected error during HF VQA query: {e}", exc_info=True)
        return "[Fallback Error] An unexpected error occurred during processing.", False