LPX55's picture
Update app.py
52ae10e verified
raw
history blame
3.72 kB
import gradio as gr
from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification
from torchvision import transforms
import torch
from PIL import Image
# Run inference on the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# First detector: a SwinV2 classifier loaded explicitly so that its image
# processor can be handed to the pipeline together with the model.
image_processor_1 = AutoImageProcessor.from_pretrained("haywoodsloan/ai-image-detector-deploy")
model_1 = Swinv2ForImageClassification.from_pretrained("haywoodsloan/ai-image-detector-deploy")
model_1 = model_1.to(device)
clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)

# Second detector, loaded from the Hub by repository id.
hfUser = "Heem2"
modelName = "AI-vs-Real-Image-Detection"
# Pass the same `device` as clf_1 so both detectors honor the GPU selection above.
clf_2 = pipeline("image-classification", model=f"{hfUser}/{modelName}", device=device)

# Class labels each detector is expected to emit in its pipeline output.
class_names_1 = ['artificial', 'real']
# NOTE(review): confirm these match the labels the Heem2 model actually emits.
class_names_2 = ['artificial', 'real']
def _prepare_image(img):
    """Validate the input image and normalize it for both detectors.

    Returns:
        An RGB PIL Image resized to exactly 256x256 (aspect ratio is not
        preserved).

    Raises:
        ValueError: if ``img`` is not a PIL Image.
    """
    if not isinstance(img, Image.Image):
        raise ValueError(f"Expected a PIL Image, but got {type(img)}")
    rgb = img if img.mode == 'RGB' else img.convert('RGB')
    return transforms.Resize((256, 256))(rgb)


def _format_prediction(prediction, class_names, confidence_threshold):
    """Summarize one image-classification pipeline result as a label string.

    Args:
        prediction: iterable of dicts with ``'label'`` and ``'score'`` keys,
            as produced by the pipeline call.
        class_names: labels to consider. Any label missing from
            ``prediction`` counts as score 0.0. Ties go to the earliest
            name in this list.
        confidence_threshold: minimum score the top class must reach.

    Returns:
        ``"Label: <name>, Confidence: <score>"`` for the highest-scoring
        class when it meets the threshold, otherwise
        ``"Uncertain Classification"``.
    """
    scores = {pred['label']: pred['score'] for pred in prediction}
    # Choose the highest-scoring class rather than the first one that clears
    # the threshold: with a low threshold several classes can pass, and
    # reporting a weaker one would contradict the model.
    best = max(class_names, key=lambda name: scores.get(name, 0.0))
    best_score = scores.get(best, 0.0)
    if best_score >= confidence_threshold:
        return f"Label: {best}, Confidence: {best_score:.4f}"
    return "Uncertain Classification"


def _classify(clf, img, class_names, confidence_threshold):
    """Run one classifier pipeline on ``img`` and format its verdict.

    Any exception raised during inference or formatting is returned as an
    ``"Error: ..."`` string. This way one failing detector does not prevent
    the other detector's result from being shown.
    """
    try:
        prediction = clf(img)
        return _format_prediction(prediction, class_names, confidence_threshold)
    except Exception as e:  # deliberate UI boundary: report instead of crashing
        return f"Error: {str(e)}"


def predict_image(img, confidence_threshold):
    """Run both AI-image detectors on ``img`` and report each verdict.

    Args:
        img: the image to analyze; must be a PIL Image.
        confidence_threshold: minimum score a detector's top class must
            reach to be reported instead of "Uncertain Classification".

    Returns:
        dict mapping each detector's display name to its verdict string.

    Raises:
        ValueError: if ``img`` is not a PIL Image.
    """
    img_pil = _prepare_image(img)
    return {
        "SwinV2": _classify(clf_1, img_pil, class_names_1, confidence_threshold),
        "AI-vs-Real-Image-Detection": _classify(
            clf_2, img_pil, class_names_2, confidence_threshold
        ),
    }
# Gradio UI: an uploaded image plus a confidence threshold go in, and a JSON
# mapping of detector name -> verdict string comes out.
image = gr.Image(label="Image to Analyze", sources=['upload'], type='pil')  # type='pil' so predict_image receives a PIL Image
confidence_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Confidence Threshold")
label = gr.JSON(label="Model Predictions")
demo = gr.Interface(
    fn=predict_image,
    inputs=[image, confidence_slider],
    outputs=label,
    title="AI Generated Classification",
)
# gr.Interface takes no `queue` argument. Queuing is requested through
# Blocks.queue() so concurrent predictions are processed in order.
demo.queue().launch()