# cursor_slides_internvl2/image_descriptor.py
import os
import torch
import io
from PIL import Image
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor
import logging
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global variables to store the model and processors
model = None
processor = None
tokenizer = None
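# The model is loaded lazily on the first call to describe_image() so that
# importing this module stays cheap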
def describe_image(image_path=None, image_data=None, show_image=False):
"""
    Load an image and generate descriptions using the Qwen2-VL-7B model.
Args:
image_path (str, optional): Path to the image file
image_data (bytes, optional): Raw image data
show_image (bool): Whether to display the image
Returns:
        dict: Generated descriptions of the image, or an error message on failure
"""
global model, processor, tokenizer
    # Initialize the model on first use, bailing out if loading fails
    if model is None or processor is None or tokenizer is None:
        if not load_model():
            return {"error": "Model failed to load; see log for details"}
# Check if we have valid input
if image_path is None and image_data is None:
return {"error": "No image provided"}
try:
# Load the image
if image_path is not None:
if not os.path.exists(image_path):
return {"error": f"Image not found at {image_path}"}
logger.info(f"Processing image from path: {image_path}")
image = Image.open(image_path).convert('RGB')
else:
logger.info("Processing image from uploaded data")
image = Image.open(io.BytesIO(image_data)).convert('RGB')
# Display the image if requested (for local testing only)
if show_image:
plt.figure(figsize=(10, 8))
plt.imshow(image)
plt.axis('off')
if image_path:
plt.title(os.path.basename(image_path))
plt.show()
# Process the image
logger.info("Generating descriptions...")
        # Preprocess the image once; the processor returns a BatchFeature whose
        # pixel_values tensor is reused for every prompt below
        inputs = processor(images=image, return_tensors="pt").to(model.device)
        # Prompts and generation budgets for each description style
        prompts = {
            "basic_description": ("Describe this image briefly.", 150),
            "detailed_description": (
                "Analyze this image in detail. Describe the main elements, "
                "any text visible, the colors, and the overall composition.",
                300,
            ),
            "technical_analysis": (
                "What can you tell me about the technical aspects of this image?",
                200,
            ),
        }
        # Run greedy decoding for each prompt against the same image
        results = {"success": True}
        for key, (prompt, max_tokens) in prompts.items():
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
            with torch.no_grad():
                output = model.generate(
                    input_ids=input_ids,
                    pixel_values=inputs.pixel_values,
                    max_new_tokens=max_tokens,
                    do_sample=False,
                )
            # Strip the echoed prompt so only the generated text remains
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            results[key] = decoded.replace(prompt, "").strip()
        return results
except Exception as e:
logger.error(f"Error processing image: {str(e)}", exc_info=True)
return {"error": f"Error generating description: {str(e)}"}
def load_model():
"""Load the model and related components"""
global model, processor, tokenizer
try:
logger.info("Loading model...")
model_id = "Qwen/Qwen2-VL-7B"
# Use explicit processor class instead of AutoProcessor
processor = CLIPImageProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load model with 4-bit quantization to reduce memory requirements
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
load_in_4bit=True,
device_map="auto"
)
logger.info("Model loaded successfully")
return True
except Exception as e:
logger.error(f"Error loading model: {str(e)}", exc_info=True)
return False
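
# Note: recent transformers releases deprecate passing load_in_4bit directly to
# from_pretrained() in favor of an explicit BitsAndBytesConfig. Below is a
# minimal sketch of the equivalent call, mirroring load_model() above; it
# assumes bitsandbytes is installed, and the function name is illustrative.
def load_model_with_quant_config_example(model_id="Qwen/Qwen2-VL-7B"):
    """Sketch: 4-bit loading via an explicit quantization config."""
    from transformers import BitsAndBytesConfig
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_config,
        device_map="auto",
    )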
def main():
"""Run in command-line mode"""
# Path to the image
image_folder = "data_temp"
image_name = "page_2.png"
image_path = os.path.join(image_folder, image_name)
# Get the description
result = describe_image(image_path=image_path, show_image=True)
# Print the results
if "error" not in result:
print("\n==== Image Description Results (Qwen2-VL-7B) ====")
print(f"\nBasic Description:\n{result['basic_description']}")
print(f"\nDetailed Description:\n{result['detailed_description']}")
print(f"\nTechnical Analysis:\n{result['technical_analysis']}")
else:
print(result["error"]) # Print error message if there was an issue
if __name__ == "__main__":
main()