# cursor_slides_internvl2/image_descriptor.py
import os
import io
import logging

import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoProcessor, BitsAndBytesConfig, Qwen2VLForConditionalGeneration
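# Note (assumption, not stated in the original file): Qwen2VLForConditionalGeneration
# requires a transformers release that includes the Qwen2-VL architecture
# (4.45 or newer), and 4-bit loading additionally expects the bitsandbytes and
# accelerate packages to be installed.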
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global variables to store the model and its processor (the Qwen2-VL
# AutoProcessor bundles the image processor and the tokenizer)
model = None
processor = None
def describe_image(image_path=None, image_data=None, show_image=False):
    """
    Load an image and generate descriptions of it using the Qwen2-VL-7B model.

    Args:
        image_path (str, optional): Path to the image file
        image_data (bytes, optional): Raw image data
        show_image (bool): Whether to display the image

    Returns:
        dict: Descriptions of the image, or a dict with an "error" key on failure
    """
    global model, processor

    # Initialize the model if it is not already loaded
    if model is None or processor is None:
        load_model()

    # Check that we have valid input
    if image_path is None and image_data is None:
        return {"error": "No image provided"}
    try:
        # Load the image
        if image_path is not None:
            if not os.path.exists(image_path):
                return {"error": f"Image not found at {image_path}"}
            logger.info(f"Processing image from path: {image_path}")
            image = Image.open(image_path).convert('RGB')
        else:
            logger.info("Processing image from uploaded data")
            image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Display the image if requested (for local testing only)
        if show_image:
            plt.figure(figsize=(10, 8))
            plt.imshow(image)
            plt.axis('off')
            if image_path:
                plt.title(os.path.basename(image_path))
            plt.show()
        # Generate a description for each prompt. Qwen2-VL expects the image to
        # be referenced through its chat template rather than passed as bare
        # pixel_values alongside plain prompt tokens.
        logger.info("Generating descriptions...")
        prompts = {
            "basic_description": (
                "Describe this image briefly.", 150),
            "detailed_description": (
                "Analyze this image in detail. Describe the main elements, "
                "any text visible, the colors, and the overall composition.", 300),
            "technical_analysis": (
                "What can you tell me about the technical aspects of this image?", 200),
        }

        results = {"success": True}
        for key, (prompt, max_new_tokens) in prompts.items():
            # Build a single-turn chat message that pairs the image with the prompt
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }]
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # The processor tokenizes the text and preprocesses the image together
            inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False
                )

            # Drop the prompt tokens so only the newly generated answer is decoded
            generated = output_ids[0][inputs.input_ids.shape[1]:]
            results[key] = processor.decode(generated, skip_special_tokens=True).strip()

        return results
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}", exc_info=True)
        return {"error": f"Error generating description: {str(e)}"}
def load_model():
    """Load the model and its processor"""
    global model, processor
    try:
        logger.info("Loading model...")
        # Use the instruction-tuned checkpoint, which ships the chat template
        # that describe_image relies on
        model_id = "Qwen/Qwen2-VL-7B-Instruct"

        # AutoProcessor bundles the Qwen2-VL image processor and tokenizer
        processor = AutoProcessor.from_pretrained(model_id)

        # Load the model with 4-bit quantization to reduce memory requirements
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            device_map="auto"
        )
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}", exc_info=True)
        return False
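# Optional helper, a minimal sketch that is not part of the original file: the
# 4-bit load above assumes a CUDA GPU via bitsandbytes, so it can be useful to
# log what hardware is available before calling load_model(). The name
# check_gpu is hypothetical.
def check_gpu():
    """Log basic CUDA availability info before attempting a 4-bit load."""
    if torch.cuda.is_available():
        name = torch.cuda.get_device_name(0)
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"CUDA device: {name} ({total_gb:.1f} GB)")
    else:
        logger.warning("No CUDA device found; 4-bit quantized loading expects a GPU")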
def main():
    """Run in command-line mode"""
    # Path to the image
    image_folder = "data_temp"
    image_name = "page_2.png"
    image_path = os.path.join(image_folder, image_name)

    # Get the description
    result = describe_image(image_path=image_path, show_image=True)

    # Print the results
    if "error" not in result:
        print("\n==== Image Description Results (Qwen2-VL-7B) ====")
        print(f"\nBasic Description:\n{result['basic_description']}")
        print(f"\nDetailed Description:\n{result['detailed_description']}")
        print(f"\nTechnical Analysis:\n{result['technical_analysis']}")
    else:
        print(result["error"])  # Print the error message if there was an issue
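# Usage sketch (illustrative, not in the original file): describe_image also
# accepts raw bytes via image_data, e.g. for an image received from a web
# upload instead of a path on disk:
#
#     with open("data_temp/page_2.png", "rb") as f:
#         result = describe_image(image_data=f.read())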
if __name__ == "__main__":
    main()