File size: 3,327 Bytes
02532a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image

# Checkpoint to caption with: the BLIP base captioning model (~1 GB), chosen
# so the demo still runs on modest hardware.
model_id = "Salesforce/blip-image-captioning-base"
# Prefer the first CUDA GPU when PyTorch can see one; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Report the runtime environment at startup so the logs show what was found.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {device == 'cuda'}")
if device == "cuda":
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
print(f"Using device: {device}")

# Filled in by load_model() the first time a caption is requested.
processor, model = None, None

def load_model():
    """Load the captioning processor and model into the module globals.

    Both are fetched for ``model_id`` through the transformers ``Auto*``
    classes, and the model is moved onto ``device``. Any exception is printed
    and reported through the return value instead of being raised.

    Returns:
        bool: True when both objects loaded, False if anything failed.
    """
    global processor, model
    try:
        print("Loading model and processor...")
        processor = AutoProcessor.from_pretrained(model_id)
        # Build the model first, then place its weights on the chosen device.
        fetched = AutoModelForVision2Seq.from_pretrained(model_id)
        model = fetched.to(device)
        print("Model loaded successfully")
    except Exception as err:
        # Best-effort: the caller turns False into a user-facing message.
        print(f"Error loading model: {err}")
        return False
    return True

def analyze_image(image):
    """Generate a BLIP caption for *image* and report which device ran it.

    The model is loaded lazily through ``load_model()`` on the first call.

    Args:
        image: The image to caption. Handled forms are:
            - ``None``: nothing uploaded yet; returns a prompt.
            - a file-path string.
            - a ``PIL.Image.Image``: what the ``gr.Image(type="pil")`` input
              in this file sends.
            - anything else ``PIL.Image.fromarray`` accepts (presumably a
              numpy array).

    Returns:
        A display string for the result textbox. On success it holds the
        caption plus device details; otherwise it holds a human-readable
        error message. Exceptions during inference are printed and reported
        in the returned text rather than raised.
    """
    global processor, model

    # An empty click in the UI sends None. Answer before (possibly) loading
    # the model, instead of failing later inside Image.fromarray(None).
    if image is None:
        return "Please upload an image."

    # Load the model on first use.
    if model is None and not load_model():
        return "Failed to load model. Check logs for details."

    try:
        if isinstance(image, str):
            # File path. The context manager closes the file handle that
            # Image.open holds, once the RGB copy has been made.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif not isinstance(image, Image.Image):
            # Array input (e.g. numpy from Gradio).
            image = Image.fromarray(image)

        # Normalize every input form to 3-channel RGB. PIL images passed in
        # directly are not guaranteed to be RGB (e.g. RGBA or L mode).
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Preprocess into model tensors on the same device as the model.
        inputs = processor(images=image, return_tensors="pt").to(device)

        # Generate caption token ids; no gradients are needed for inference.
        with torch.no_grad():
            output = model.generate(**inputs, max_length=100)

        caption = processor.decode(output[0], skip_special_tokens=True)

        if device == "cuda":
            memory_info = torch.cuda.memory_allocated() / 1024**2
            return (
                f"Caption: {caption}\n\n"
                f"Using device: {device} ({torch.cuda.get_device_name(0)})\n"
                f"GPU memory used: {memory_info:.2f} MB"
            )
        return f"Caption: {caption}\n\nUsing device: {device}"

    except Exception as e:
        # UI boundary: keep the app responsive and surface the error text.
        print(f"Error during inference: {e}")
        return f"Error during inference: {str(e)}"

# Assemble the UI: image input and trigger button on the left, result on the right.
with gr.Blocks(title="Simple GPU Test") as demo:
    gr.Markdown("# Simple GPU Test with BLIP Image Captioning")

    with gr.Row():
        # Left column: input controls plus a GPU status banner.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an image")
            submit_btn = gr.Button("Generate Caption")
            # The banner text is decided once, when the UI is built.
            gr.Markdown(
                f"✅ **GPU detected**: {torch.cuda.get_device_name(0)}"
                if torch.cuda.is_available()
                else "❌ **No GPU detected**. Running on CPU."
            )

        # Right column: where the caption (or error) is shown.
        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=5)

    # Clicking the button captions the uploaded image into the textbox.
    submit_btn.click(
        inputs=[image_input],
        outputs=[output_text],
        fn=analyze_image,
    )

# Start the web server when run as a script. Binding to 0.0.0.0 listens on
# all network interfaces, so the app is reachable beyond localhost.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
    )