clockclock committed on
Commit 8a64376 · verified · 1 Parent(s): 0749470

Create app.py

Files changed (1)
  1. app.py +274 -0
app.py ADDED
@@ -0,0 +1,274 @@
import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam
from captum.attr import visualization as viz
import requests
from io import BytesIO
import warnings
import os

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Force CPU usage for Hugging Face Spaces
device = torch.device("cpu")
torch.set_num_threads(1)  # Optimize for CPU usage

# --- 1. Load Model and Processor ---
print("Loading model and processor...")
try:
    model_id = "Organika/sdxl-detector"
    processor = AutoImageProcessor.from_pretrained(model_id)

    # Load model with CPU-optimized settings
    model = AutoModelForImageClassification.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True
    )
    model.to(device)
    model.eval()
    print("Model and processor loaded successfully on CPU.")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# --- 2. Define the Explainability (Grad-CAM) Function ---
def generate_heatmap(image_tensor, original_image, target_class_index):
    try:
        # Ensure tensor is on CPU
        image_tensor = image_tensor.to(device)

        # Define wrapper function for the model forward pass.
        # Gradients must stay enabled here: Grad-CAM is gradient-based,
        # so the forward pass is NOT wrapped in torch.no_grad().
        def model_forward_wrapper(input_tensor):
            outputs = model(pixel_values=input_tensor)
            return outputs.logits

        # Get the target layer for Grad-CAM
        # For SWIN transformer, use the layer normalization layer
        target_layer = model.swin.layernorm

        # Initialize LayerGradCam with the wrapper function
        lgc = LayerGradCam(model_forward_wrapper, target_layer)

        # Generate attributions (requires gradient computation)
        attributions = lgc.attribute(
            image_tensor,
            target=target_class_index,
            relu_attributions=True
        )

        # Convert attributions to numpy for visualization
        heatmap = np.transpose(
            attributions.squeeze(0).cpu().detach().numpy(),
            (1, 2, 0)
        )

        # Create visualization. use_pyplot=False so captum returns a
        # matplotlib Figure instead of trying to display it on a headless server.
        visualized_figure, _ = viz.visualize_image_attr(
            heatmap,
            np.array(original_image),
            method="blended_heat_map",
            sign="all",
            show_colorbar=True,
            title="AI Detection Heatmap",
            alpha_overlay=0.6,
            use_pyplot=False
        )

        # Render the Figure to an RGB array so the gr.Image output can display it
        visualized_figure.canvas.draw()
        visualized_image = np.asarray(visualized_figure.canvas.buffer_rgba())[..., :3]

        return visualized_image

    except Exception as e:
        print(f"Error generating heatmap: {e}")
        # Return original image if heatmap generation fails
        return np.array(original_image)

# --- 3. Main Prediction Function ---
def predict(image_upload: Image.Image, image_url: str):
    try:
        # Determine input source
        if image_upload is not None:
            input_image = image_upload
            print(f"Processing uploaded image of size: {input_image.size}")
        elif image_url and image_url.strip():
            try:
                response = requests.get(image_url, timeout=10)
                response.raise_for_status()
                input_image = Image.open(BytesIO(response.content))
                print(f"Processing image from URL: {image_url}")
            except Exception as e:
                raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
        else:
            raise gr.Error("Please upload an image or provide a URL to analyze.")

        # Convert to RGB if necessary (handles RGBA, grayscale, and palette images)
        if input_image.mode != 'RGB':
            input_image = input_image.convert('RGB')

        # Resize image if too large to save memory
        max_size = 512
        if max(input_image.size) > max_size:
            input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        # Process image
        inputs = processor(images=input_image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Make prediction
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Calculate probabilities
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class_idx = logits.argmax(-1).item()
        confidence_score = probabilities[0][predicted_class_idx].item()
        predicted_label = model.config.id2label[predicted_class_idx]

        # Generate explanation
        if predicted_label.lower() == 'ai':
            explanation = (
                f"🤖 The model is {confidence_score:.2%} confident that this image is **AI-GENERATED**.\n\n"
                "The heatmap highlights areas that most influenced this decision. "
                "Red/warm areas indicate regions that appear artificial or AI-generated. "
                "Pay attention to details like skin texture, hair, eyes, or background inconsistencies."
            )
        else:
            explanation = (
                f"👤 The model is {confidence_score:.2%} confident that this image is **HUMAN-MADE**.\n\n"
                "The heatmap shows areas the model considers natural and realistic. "
                "Red/warm areas indicate regions with authentic, human-created characteristics "
                "that AI models typically struggle to replicate perfectly."
            )

        print("Generating heatmap...")
        heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
        print("Heatmap generated successfully.")

        # Create labels dictionary for the gradio Label output
        labels_dict = {
            model.config.id2label[i]: float(probabilities[0][i])
            for i in range(len(model.config.id2label))
        }

        return labels_dict, explanation, heatmap_image

    except gr.Error:
        # Re-raise user-facing errors unchanged instead of wrapping them again
        raise
    except Exception as e:
        print(f"Error in prediction: {e}")
        raise gr.Error(f"An error occurred during prediction: {str(e)}")

# --- 4. Gradio Interface ---
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="AI Image Detector",
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .tab-nav {
        margin-bottom: 1rem;
    }
    """
) as demo:
    gr.Markdown(
        """
        # 🔍 AI Image Detector with Explainability

        Determine whether an image is AI-generated or human-made using a pretrained detection model.

        **Features:**
        - 🎯 High-accuracy detection using the Organika/sdxl-detector model
        - 🔥 **Heatmap visualization** showing which areas influenced the decision
        - 📱 Support for both file uploads and URL inputs
        - ⚡ Optimized for CPU deployment

        **How to use:** Upload an image or paste a URL, then click "Analyze Image" to see the results and heatmap.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📥 Input")

            with gr.Tabs():
                with gr.TabItem("📁 Upload File"):
                    input_image_upload = gr.Image(
                        type="pil",
                        label="Upload Your Image",
                        height=300
                    )
                with gr.TabItem("🔗 Use URL"):
                    input_image_url = gr.Textbox(
                        label="Paste Image URL here",
                        placeholder="https://example.com/image.jpg"
                    )

            submit_btn = gr.Button(
                "🔍 Analyze Image",
                variant="primary",
                size="lg"
            )

            gr.Markdown(
                """
                ### ℹ️ Tips
                - Supported formats: JPG, PNG, WebP
                - Images are automatically resized for optimal processing
                - For best results, use clear, high-quality images
                """
            )

        with gr.Column(scale=2):
            gr.Markdown("### 📊 Results")

            with gr.Row():
                with gr.Column():
                    output_label = gr.Label(
                        label="Prediction Confidence",
                        num_top_classes=2
                    )
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Detailed Explanation",
                        lines=6,
                        interactive=False
                    )

            output_heatmap = gr.Image(
                label="🔥 AI Detection Heatmap - Red areas influenced the decision most",
                height=400
            )

    # Connect the interface
    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap]
    )

    # Add examples. Inputs are listed in the same order as predict's parameters
    # (upload first, then URL); the URL-based examples pass no uploaded file.
    gr.Examples(
        examples=[
            [None, "https://images.unsplash.com/photo-1494790108755-2616b612b786"],
            [None, "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d"],
        ],
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap],
        fn=predict,
        cache_examples=False
    )

# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(
        debug=False,
        share=False,
        server_name="0.0.0.0",
        server_port=7860
    )
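
For reviewers who want to sanity-check the detector outside the Gradio app, the sketch below runs the same Organika/sdxl-detector checkpoint directly with transformers and prints the class probabilities. It is a minimal sketch, not part of this commit; the file name sanity_check.py and the local image path "sample.jpg" are illustrative assumptions, and it mirrors the preprocessing and softmax steps used in predict() above.

# sanity_check.py -- minimal sketch, not part of this commit.
# Assumes a local test image at "sample.jpg" (hypothetical path).
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id).eval()

image = Image.open("sample.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# Print the probability for each label (e.g. "ai" vs. "human")
probs = torch.softmax(logits, dim=-1)[0]
for idx, label in model.config.id2label.items():
    print(f"{label}: {probs[idx]:.2%}")

Note that app.py imports gradio, torch, transformers, captum, pillow, and requests at runtime, so the Space presumably needs a matching requirements.txt alongside app.py; that file is not part of this commit.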