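"""Enhanced Image Analyzer.

A Gradio app that classifies an uploaded image with CLIP (fast, zero-shot
scoring against the CATEGORIES list below) and can optionally generate a
detailed BLIP caption plus heuristic text/chart/subject analyses.
"""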
import gradio as gr
from PIL import Image
import torch
import os
import time
import numpy as np


# Reduce CUDA allocator fragmentation when the large models are loaded.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"


from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration


CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
DETAILED_MODEL_ID = "Salesforce/blip-image-captioning-large"
USE_GPU = torch.cuda.is_available()


# Models are loaded lazily on first use and cached in these globals.
clip_model = None
clip_processor = None
detailed_model = None
detailed_processor = None


def load_clip_model():
    """Load the CLIP model for fast classification"""
    global clip_model, clip_processor

    if clip_model is not None and clip_processor is not None:
        return True

    print("Loading CLIP model...")
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)

        if USE_GPU:
            clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to("cuda")
        else:
            clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID)

        clip_model.eval()
        print("CLIP model loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading CLIP model: {str(e)}")
        return False


def load_detailed_model():
    """Load the BLIP model for detailed image analysis"""
    global detailed_model, detailed_processor

    if detailed_model is not None and detailed_processor is not None:
        return True

    print("Loading BLIP model...")
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        detailed_processor = BlipProcessor.from_pretrained(DETAILED_MODEL_ID)

        # Load the weights in float16 on GPU to roughly halve memory use.
        detailed_model = BlipForConditionalGeneration.from_pretrained(
            DETAILED_MODEL_ID,
            torch_dtype=torch.float16 if USE_GPU else torch.float32
        )

        if USE_GPU:
            detailed_model = detailed_model.to("cuda")

        detailed_model.eval()
        print("BLIP model loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading BLIP model: {str(e)}")
        if "CUDA out of memory" in str(e):
            print("Not enough GPU memory for the detailed model")
        return False


CATEGORIES = [
    "a photograph", "a painting", "a drawing", "a digital art",
    "landscape", "portrait", "cityscape", "animals", "food", "vehicle",
    "building", "nature", "people", "abstract art", "technology",
    "interior", "exterior", "night scene", "beach", "mountains",
    "forest", "water", "flowers", "sports",
    "a person", "multiple people", "a child", "an elderly person",
    "a dog", "a cat", "wildlife", "a bird", "a car", "a building",
    "a presentation slide", "a graph", "a chart", "a diagram", "text document",
    "a screenshot", "a map", "a table of data", "a scientific figure"
]

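# The labels above are the candidate texts CLIP scores against the image in
# get_clip_classification() below; the eight best matches are reported.
# Editing this list changes the taxonomy with no retraining.

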
def get_detailed_analysis(image):
    """Get detailed analysis from the image using BLIP model"""
    try:
        start_time = time.time()

        if not load_detailed_model():
            return "Couldn't load detailed analysis model."

        if isinstance(image, np.ndarray):
            image_pil = Image.fromarray(image).convert('RGB')
        else:
            image_pil = image.convert('RGB')

        # Cap the longest side at 600 px to keep memory use and latency down.
        max_size = 600
        width, height = image_pil.size
        if max(width, height) > max_size:
            if width > height:
                new_width = max_size
                new_height = int(height * (max_size / width))
            else:
                new_height = max_size
                new_width = int(width * (max_size / height))
            image_pil = image_pil.resize((new_width, new_height), Image.LANCZOS)

        device = "cuda" if USE_GPU else "cpu"

        # Unconditional captioning pass.
        inputs = detailed_processor(image_pil, return_tensors="pt")
        if USE_GPU:
            # The BLIP weights are float16 on GPU, so cast floating-point inputs
            # (pixel_values) to match before running them through the model.
            inputs = {k: (v.to(device, dtype=torch.float16) if v.is_floating_point() else v.to(device))
                      for k, v in inputs.items()}

        with torch.no_grad():
            output_ids = detailed_model.generate(
                **inputs,
                max_length=50,
                num_beams=5,
                do_sample=False,
                early_stopping=True
            )
        base_description = detailed_processor.decode(output_ids[0], skip_special_tokens=True)

        analyses = {
            "text": None,
            "chart": None,
            "subject": None
        }

        # Short prompts that condition BLIP toward text, chart, and subject details.
        ultra_simple_prompts = {
            f"Text in {base_description[:20]}...": "text",
            f"Charts in {base_description[:20]}...": "chart",
            f"Subject of {base_description[:20]}...": "subject"
        }

        for prompt, analysis_type in ultra_simple_prompts.items():
            inputs = detailed_processor(image_pil, text=prompt, return_tensors="pt")

            if USE_GPU:
                inputs = {k: (v.to(device, dtype=torch.float16) if v.is_floating_point() else v.to(device))
                          for k, v in inputs.items()}

            with torch.no_grad():
                output_ids = detailed_model.generate(
                    **inputs,
                    max_length=75,
                    num_beams=3,
                    do_sample=True,
                    temperature=0.7,
                    repetition_penalty=1.2,
                    early_stopping=True
                )

            result = detailed_processor.decode(output_ids[0], skip_special_tokens=True)

            # Strip the prompt echo and other boilerplate from the generated text.
            colon_parts = result.split(":")
            if len(colon_parts) > 1:
                result = ":".join(colon_parts[1:]).strip()

            if base_description in result:
                result = result.replace(base_description, "").strip()

            for p in ultra_simple_prompts.keys():
                if p in result:
                    result = result.replace(p, "").strip()

            if base_description[:20] in result:
                result = result.replace(base_description[:20], "").strip()

            remove_patterns = [
                "text in", "charts in", "subject of",
                "in detail", "describe", "this image", "the image",
                "can you", "do you", "is there", "are there", "i can see",
                "i see", "there is", "there are", "it looks like",
                "appears to be", "seems to be", "might be", "could be",
                "i think", "i believe", "probably", "possibly", "maybe",
                "it is", "this is", "that is", "these are", "those are",
                "image shows", "picture shows", "image contains", "picture contains",
                "in the image", "in this image", "of this image", "from this image",
                "based on", "according to", "looking at", "from what i can see",
                "appears to show", "depicts", "represents", "illustrates", "demonstrates",
                "presents", "displays", "portrays", "reveals", "indicates", "suggests",
                "we can see", "you can see", "one can see"
            ]

            # Remove every (case-insensitive) occurrence of each filler phrase.
            for pattern in remove_patterns:
                lower_result = result.lower()
                while pattern.lower() in lower_result:
                    idx = lower_result.find(pattern.lower())
                    result = result[:idx] + result[idx + len(pattern):]
                    lower_result = result.lower()

            # Tidy up leading punctuation and leftover ellipses.
            result = result.strip()
            while result and result[0] in ",.;:?!-":
                result = result[1:].strip()

            result = result.replace("...", "").strip()

            if result:
                result = result[0].upper() + result[1:]

            analyses[analysis_type] = result

        output_text = f"## Detailed Description\n{base_description}\n\n"

        # Only include sections whose answers look substantive.
        if analyses['text'] and len(analyses['text']) > 5 and not any(x in analyses['text'].lower() for x in ["no text", "not any text", "can't see", "cannot see", "don't see", "couldn't find"]):
            output_text += f"## Text Content\n{analyses['text']}\n\n"

        if analyses['chart'] and len(analyses['chart']) > 5 and not any(x in analyses['chart'].lower() for x in ["no chart", "not any chart", "no graph", "not any graph", "can't see", "cannot see", "don't see", "couldn't find"]):
            output_text += f"## Chart Analysis\n{analyses['chart']}\n\n"

        output_text += f"## Main Subject\n{analyses['subject'] or 'Unable to determine main subject.'}"

        if USE_GPU:
            torch.cuda.empty_cache()

        elapsed_time = time.time() - start_time
        return output_text

    except Exception as e:
        print(f"Error in detailed analysis: {str(e)}")
        if USE_GPU:
            torch.cuda.empty_cache()
        return f"Error in detailed analysis: {str(e)}"


def get_clip_classification(image):
    """Get fast classification using CLIP"""
    if not load_clip_model():
        return []

    try:
        inputs = clip_processor(
            text=CATEGORIES,
            images=image,
            return_tensors="pt",
            padding=True
        )

        if USE_GPU:
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        with torch.inference_mode():
            outputs = clip_model(**inputs)

        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)

        values, indices = probs[0].topk(8)

        return [(CATEGORIES[idx], value.item() * 100) for value, idx in zip(values, indices)]
    except Exception as e:
        print(f"Error in CLIP classification: {str(e)}")
        return []

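# Note: logits_per_image holds CLIP's image-text similarity scores for every
# label in CATEGORIES, and softmax turns them into a relative distribution, so
# the reported percentages are confidences relative to this label set rather
# than absolute probabilities.

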
def process_image(image, get_detailed=False):
    """Process image with both fast and detailed analysis"""
    if image is None:
        return "Please upload an image to analyze."

    try:
        start_time = time.time()

        if hasattr(image, 'mode') and image.mode != 'RGB':
            image = image.convert('RGB')

        # Downscale large uploads so the longest side is at most 600 px.
        if max(image.size) > 600:
            ratio = 600 / max(image.size)
            new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
            image = image.resize(new_size, Image.LANCZOS)

        # Fast zero-shot classification with CLIP.
        categories = get_clip_classification(image)

        result = "## Image Classification\n"
        result += "This image appears to contain:\n"
        for category, confidence in categories:
            result += f"- {category.title()} ({confidence:.1f}%)\n"

        # Optional slower pass with BLIP.
        if get_detailed:
            result += "\n## Detailed Analysis\n"
            detailed_result = get_detailed_analysis(image)
            result += detailed_result

        elapsed_time = time.time() - start_time
        result += f"\n\nAnalysis completed in {elapsed_time:.2f} seconds."

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return result

    except Exception as e:
        print(f"Error: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return f"Error processing image: {str(e)}"


with gr.Blocks(title="Enhanced Image Analyzer") as demo:
    gr.Markdown("# Enhanced Image Analyzer")
    gr.Markdown("Upload an image and choose between fast classification or detailed analysis.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload an image")
            detailed_checkbox = gr.Checkbox(label="Get detailed analysis (slower but better quality)", value=False)
            analyze_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Column():
            output = gr.Markdown(label="Analysis Results")

    analyze_btn.click(
        fn=process_image,
        inputs=[input_image, detailed_checkbox],
        outputs=output
    )

    if os.path.exists("data_temp"):
        examples = [os.path.join("data_temp", f) for f in os.listdir("data_temp")
                    if f.endswith(('.png', '.jpg', '.jpeg'))]
        if examples:
            gr.Examples(examples=examples, inputs=input_image)


if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
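
# To run locally, install the dependencies (roughly: gradio, torch, transformers,
# pillow, numpy) and execute this file with Python; the filename is up to you.
# The server binds to 0.0.0.0 and reads its port from the PORT environment
# variable, defaulting to 7860.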