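"""Generate image descriptions with a quantized Qwen2-VL vision-language model.

Loads the model once, then produces a basic description, a detailed
description, and a technical analysis for an image supplied either as a
file path or as raw bytes.
"""
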
import io
import logging
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import AutoProcessor, BitsAndBytesConfig, Qwen2VLForConditionalGeneration

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global variables to cache the model and processor between calls
# (the AutoProcessor for Qwen2-VL bundles the tokenizer and image processor)
model = None
processor = None


def describe_image(image_path=None, image_data=None, show_image=False):
    """
    Load an image and generate descriptions using the Qwen2-VL-7B-Instruct model.

    Args:
        image_path (str, optional): Path to the image file
        image_data (bytes, optional): Raw image data
        show_image (bool): Whether to display the image

    Returns:
        dict: Descriptions of the image, or an "error" key on failure
    """
    global model, processor

    # Lazily initialize the model on first use
    if model is None or processor is None:
        if not load_model():
            return {"error": "Failed to load model"}

    # Check that we have valid input
    if image_path is None and image_data is None:
        return {"error": "No image provided"}

    try:
        # Load the image
        if image_path is not None:
            if not os.path.exists(image_path):
                return {"error": f"Image not found at {image_path}"}
            logger.info(f"Processing image from path: {image_path}")
            image = Image.open(image_path).convert('RGB')
        else:
            logger.info("Processing image from uploaded data")
            image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Display the image if requested (for local testing only)
        if show_image:
            plt.figure(figsize=(10, 8))
            plt.imshow(image)
            plt.axis('off')
            if image_path:
                plt.title(os.path.basename(image_path))
            plt.show()

        logger.info("Generating descriptions...")

        # Qwen2-VL expects the image as a placeholder inside a chat-formatted
        # conversation; the processor pairs the rendered template with the
        # PIL image to build the model inputs.
        def generate_description(prompt, max_new_tokens):
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }]
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            inputs = processor(
                text=[text], images=[image], return_tensors="pt"
            ).to(model.device)
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False
                )
            # Drop the prompt tokens so only the generated answer is decoded
            generated_ids = output_ids[:, inputs.input_ids.shape[1]:]
            return processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0].strip()

        # Basic description
        basic_description = generate_description(
            "Describe this image briefly.", max_new_tokens=150
        )

        # Detailed description
        detailed_description = generate_description(
            "Analyze this image in detail. Describe the main elements, "
            "any text visible, the colors, and the overall composition.",
            max_new_tokens=300
        )

        # Technical analysis
        technical_analysis = generate_description(
            "What can you tell me about the technical aspects of this image?",
            max_new_tokens=200
        )

        return {
            "success": True,
            "basic_description": basic_description,
            "detailed_description": detailed_description,
            "technical_analysis": technical_analysis
        }
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}", exc_info=True)
        return {"error": f"Error generating description: {str(e)}"}


def load_model():
    """Load the Qwen2-VL model and processor"""
    global model, processor
    try:
        logger.info("Loading model...")
        # The Instruct checkpoint is needed for prompt-following generation
        model_id = "Qwen/Qwen2-VL-7B-Instruct"

        # AutoProcessor bundles the Qwen2-VL image processor and tokenizer
        processor = AutoProcessor.from_pretrained(model_id)

        # Load the model with 4-bit quantization to reduce memory requirements
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config,
            device_map="auto"
        )
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}", exc_info=True)
        return False


def main():
    """Run in command-line mode"""
    # Path to the image
    image_folder = "data_temp"
    image_name = "page_2.png"
    image_path = os.path.join(image_folder, image_name)

    # Get the descriptions
    result = describe_image(image_path=image_path, show_image=True)

    # Print the results
    if "error" not in result:
        print("\n==== Image Description Results (Qwen2-VL-7B-Instruct) ====")
        print(f"\nBasic Description:\n{result['basic_description']}")
        print(f"\nDetailed Description:\n{result['detailed_description']}")
        print(f"\nTechnical Analysis:\n{result['technical_analysis']}")
    else:
        print(result["error"])  # Print the error message if something failed


if __name__ == "__main__":
    main()
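
# Hypothetical example of the bytes-based entry point (e.g. when the image
# comes from an upload handler rather than the local filesystem):
#
#     with open("data_temp/page_2.png", "rb") as f:
#         result = describe_image(image_data=f.read())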