# File size: 6,262 Bytes
# e59dc66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import torch
import io
from PIL import Image
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor
import logging
import time

# Route log records through the root handler using a timestamped
# "time - logger - level - message" layout, showing INFO and above.
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Lazily populated by load_model(); all three stay None until then.
model = processor = tokenizer = None

# (result key, prompt text, max_new_tokens) for each description that
# describe_image() produces, in the order they are generated.
_DESCRIPTION_PROMPTS = (
    (
        "basic_description",
        "Describe this image briefly.",
        150,
    ),
    (
        "detailed_description",
        "Analyze this image in detail. Describe the main elements, any text visible, the colors, and the overall composition.",
        300,
    ),
    (
        "technical_analysis",
        "What can you tell me about the technical aspects of this image?",
        200,
    ),
)


def _open_rgb_image(image_path=None, image_data=None):
    """Open the image from a path (preferred) or raw bytes and return an RGB copy."""
    if image_path is not None:
        logger.info(f"Processing image from path: {image_path}")
        source = image_path
    else:
        logger.info("Processing image from uploaded data")
        source = io.BytesIO(image_data)

    # convert() returns a new image, so the handle opened by Image.open()
    # can be released as soon as the with-block exits.
    with Image.open(source) as raw_image:
        return raw_image.convert('RGB')


def _generate_text(prompt, pixel_values, max_new_tokens):
    """Run one greedy generation for *prompt* and return the decoded text."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # Strip the echoed prompt so only the model's answer remains.
    return decoded.replace(prompt, "").strip()


def describe_image(image_path=None, image_data=None, show_image=False):
    """
    Load an image and generate a description using Qwen2-VL-7B model.

    Input is validated before the model is loaded, so a missing image is
    reported without paying the model start-up cost.

    Args:
        image_path (str, optional): Path to the image file. Takes precedence
            over image_data when both are given.
        image_data (bytes, optional): Raw image data
        show_image (bool): Whether to display the image with matplotlib
            (for local testing only)

    Returns:
        dict: On success, {"success": True, "basic_description": ...,
            "detailed_description": ..., "technical_analysis": ...}.
            On any failure, {"error": <message>}; exceptions raised while
            processing are logged and reported this way rather than raised.
    """
    # Check if we have valid input
    if image_path is None and image_data is None:
        return {"error": "No image provided"}

    try:
        if image_path is not None and not os.path.exists(image_path):
            return {"error": f"Image not found at {image_path}"}

        # Initialize model if not already loaded
        if model is None or processor is None or tokenizer is None:
            load_model()
            # Re-check the globals instead of trusting a return value, so a
            # failed load is reported clearly instead of crashing later on.
            if model is None or processor is None or tokenizer is None:
                return {"error": "Model could not be loaded; see logs for details"}

        image = _open_rgb_image(image_path=image_path, image_data=image_data)

        # Display the image if requested (for local testing only)
        if show_image:
            plt.figure(figsize=(10, 8))
            plt.imshow(image)
            plt.axis('off')
            if image_path:
                plt.title(os.path.basename(image_path))
            plt.show()

        logger.info("Generating descriptions...")

        # The image is preprocessed once and reused for every prompt.
        image_inputs = processor(images=image, return_tensors="pt").to(model.device)

        result = {"success": True}
        for key, prompt, max_new_tokens in _DESCRIPTION_PROMPTS:
            result[key] = _generate_text(
                prompt, image_inputs.pixel_values, max_new_tokens
            )
        return result

    except Exception as e:
        logger.error(f"Error processing image: {str(e)}", exc_info=True)
        return {"error": f"Error generating description: {str(e)}"}

def load_model():
    """Load the model, image processor and tokenizer into the module globals.

    The three components are loaded into local variables first and only
    published to the globals once all of them succeeded, so a failure part
    way through never leaves the module in a half-initialized state (and a
    failed reload does not clobber previously loaded components).

    Returns:
        bool: True when everything loaded, False when any step raised
            (the exception is logged, not propagated).
    """
    global model, processor, tokenizer

    model_id = "Qwen/Qwen2-VL-7B"

    try:
        logger.info("Loading model...")

        # Use explicit processor class instead of AutoProcessor
        # NOTE(review): Qwen2-VL checkpoints are normally loaded through
        # their dedicated processor/model classes - confirm that
        # CLIPImageProcessor and AutoModelForCausalLM accept this checkpoint.
        loaded_processor = CLIPImageProcessor.from_pretrained(model_id)
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Load model with 4-bit quantization to reduce memory requirements
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            load_in_4bit=True,
            device_map="auto",
        )
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}", exc_info=True)
        return False

    # Publish all three together now that every load has succeeded.
    model, processor, tokenizer = loaded_model, loaded_processor, loaded_tokenizer
    logger.info("Model loaded successfully")
    return True

def main():
    """Describe the sample image data_temp/page_2.png and print the outcome.

    The image is also displayed (show_image=True). On failure only the
    error message returned by describe_image() is printed.
    """
    sample_path = os.path.join("data_temp", "page_2.png")
    outcome = describe_image(image_path=sample_path, show_image=True)

    # Bail out early with the error message if there was an issue
    if "error" in outcome:
        print(outcome["error"])
        return

    print("\n==== Image Description Results (Qwen2-VL-7B) ====")
    sections = (
        ("Basic Description", "basic_description"),
        ("Detailed Description", "detailed_description"),
        ("Technical Analysis", "technical_analysis"),
    )
    for heading, key in sections:
        print(f"\n{heading}:\n{outcome[key]}")

# Run the command-line demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()