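"""Generate image descriptions with the Qwen2-VL vision-language model.

Loads an image from a file path or raw bytes, optionally displays it with
matplotlib, and asks the model for a basic, a detailed, and a technical
description. Run as a script to describe data_temp/page_2.png.
"""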

import io
import logging
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import AutoProcessor, BitsAndBytesConfig, Qwen2VLForConditionalGeneration

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Lazily initialized globals so the heavy model load happens on first use.
# The AutoProcessor bundles the image processor and the tokenizer, so no
# separate tokenizer global is needed.
model = None
processor = None

def describe_image(image_path=None, image_data=None, show_image=False):
    """
    Load an image and generate descriptions using the Qwen2-VL-7B-Instruct model.

    Args:
        image_path (str, optional): Path to the image file.
        image_data (bytes, optional): Raw image data.
        show_image (bool): Whether to display the image.

    Returns:
        dict: Descriptions of the image, or an "error" key on failure.
    """
    global model, processor

    if model is None or processor is None:
        if not load_model():
            return {"error": "Failed to load model"}

    if image_path is None and image_data is None:
        return {"error": "No image provided"}

    try:
        if image_path is not None:
            if not os.path.exists(image_path):
                return {"error": f"Image not found at {image_path}"}
            logger.info(f"Processing image from path: {image_path}")
            image = Image.open(image_path).convert('RGB')
        else:
            logger.info("Processing image from uploaded data")
            image = Image.open(io.BytesIO(image_data)).convert('RGB')

        if show_image:
            plt.figure(figsize=(10, 8))
            plt.imshow(image)
            plt.axis('off')
            if image_path:
                plt.title(os.path.basename(image_path))
            plt.show()

        logger.info("Generating descriptions...")

        # Each entry maps an output key to its prompt and generation budget.
        prompts = {
            "basic_description": ("Describe this image briefly.", 150),
            "detailed_description": (
                "Analyze this image in detail. Describe the main elements, "
                "any text visible, the colors, and the overall composition.",
                300,
            ),
            "technical_analysis": (
                "What can you tell me about the technical aspects of this image?",
                200,
            ),
        }

        results = {"success": True}
        for key, (prompt, max_new_tokens) in prompts.items():
            # Qwen2-VL expects a chat-formatted prompt with an image placeholder;
            # the processor then pairs the placeholder with the PIL image.
            messages = [{
                "role": "user",
                "content": [{"type": "image"}, {"type": "text", "text": prompt}],
            }]
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            inputs = processor(
                text=[text], images=[image], return_tensors="pt"
            ).to(model.device)

            with torch.no_grad():
                output = model.generate(
                    **inputs, max_new_tokens=max_new_tokens, do_sample=False
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            generated = output[:, inputs.input_ids.shape[1]:]
            results[key] = processor.batch_decode(
                generated, skip_special_tokens=True
            )[0].strip()

        return results

    except Exception as e:
        logger.error(f"Error processing image: {str(e)}", exc_info=True)
        return {"error": f"Error generating description: {str(e)}"}


def load_model():
    """Load the Qwen2-VL model and its processor."""
    global model, processor

    try:
        logger.info("Loading model...")
        # The instruction-tuned variant is needed to follow the free-form
        # prompts used above.
        model_id = "Qwen/Qwen2-VL-7B-Instruct"

        # The AutoProcessor bundles the image processor and the tokenizer.
        processor = AutoProcessor.from_pretrained(model_id)

        # 4-bit quantization (requires the bitsandbytes package) keeps the
        # 7B model within a single consumer GPU's memory.
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
            device_map="auto",
        )
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}", exc_info=True)
        return False


def main():
    """Run in command-line mode."""
    image_folder = "data_temp"
    image_name = "page_2.png"
    image_path = os.path.join(image_folder, image_name)

    result = describe_image(image_path=image_path, show_image=True)

    if "error" not in result:
        print("\n==== Image Description Results (Qwen2-VL-7B-Instruct) ====")
        print(f"\nBasic Description:\n{result['basic_description']}")
        print(f"\nDetailed Description:\n{result['detailed_description']}")
        print(f"\nTechnical Analysis:\n{result['technical_analysis']}")
    else:
        print(result["error"])


if __name__ == "__main__":
    main()
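
# Usage sketch (hypothetical): describing an image from raw bytes instead of
# a file path, e.g. when the image arrives as an upload rather than from disk:
#
#     with open("data_temp/page_2.png", "rb") as f:
#         result = describe_image(image_data=f.read())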