# Content_safety / app.py
import os
import zipfile
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import gradio as gr
# Ensure PyTorch is installed (best-effort fallback; if torch were truly missing,
# the imports above would already have failed before reaching this point)
try:
    import torch
except ModuleNotFoundError:
    print("PyTorch is not installed. Installing now...")
    os.system("pip install torch torchvision torchaudio")
    import torch
# Step 1: Unzip the dataset
if not os.path.exists("data"):
    os.makedirs("data")

print("Extracting Data.zip...")
with zipfile.ZipFile("Data.zip", 'r') as zip_ref:
    zip_ref.extractall("data")
print("Extraction complete.")
# Step 2: Dynamically find the 'safe' and 'unsafe' folders
def find_dataset_path(root_dir):
    for root, dirs, files in os.walk(root_dir):
        if 'safe' in dirs and 'unsafe' in dirs:
            return root
    return None
# Look for 'safe' and 'unsafe' inside 'data/Data'
dataset_path = find_dataset_path("data/Data")
if dataset_path is None:
    print("Debugging extracted structure:")
    for root, dirs, files in os.walk("data"):
        print(f"Root: {root}")
        print(f"Directories: {dirs}")
        print(f"Files: {files}")
    raise FileNotFoundError("Expected 'safe' and 'unsafe' folders not found inside 'data/Data'. Please check the Data.zip structure.")
print(f"Dataset path found: {dataset_path}")
# Step 3: Define Custom Dataset Class
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for label, folder in enumerate(["safe", "unsafe"]):  # 0 = safe, 1 = unsafe
            folder_path = os.path.join(root_dir, folder)
            if not os.path.exists(folder_path):
                raise FileNotFoundError(f"Folder '{folder}' not found in '{root_dir}'")
            for filename in os.listdir(folder_path):
                if filename.endswith((".jpg", ".jpeg", ".png")):  # Only load image files
                    self.image_paths.append(os.path.join(folder_path, filename))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
# Step 4: Data Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),         # Resize to 224x224 pixels
    transforms.ToTensor(),                 # Convert to tensor
    transforms.Normalize((0.5,), (0.5,)),  # Normalize image values
])
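# Note: CLIP checkpoints are normally fed images normalized with the model's own
# image mean/std (as applied by CLIPProcessor); the generic 0.5 mean/std above is a
# simplification and may cost some accuracy.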
# Step 5: Load the Dataset
train_dataset = CustomImageDataset(dataset_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# Debugging: Check the dataset
print(f"Number of samples in the dataset: {len(train_dataset)}")
if len(train_dataset) == 0:
    raise ValueError("The dataset is empty. Please check if 'Data.zip' is correctly unzipped and contains 'safe' and 'unsafe' folders.")
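# Optional sanity check (illustrative, left commented out): peek at one batch to confirm
# shapes — with batch_size=8 and at least 8 samples this should print
# torch.Size([8, 3, 224, 224]) and torch.Size([8]).
# _imgs, _lbls = next(iter(train_loader))
# print(_imgs.shape, _lbls.shape)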
# Step 6: Load Pretrained CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Add a Classification Layer
model.classifier = nn.Linear(model.visual_projection.out_features, 2) # 2 classes: safe, unsafe
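# For "openai/clip-vit-base-patch32" the visual projection output is 512-dimensional,
# so this is a 512 -> 2 linear head. Only the head's parameters are passed to the
# optimizer below, so the CLIP backbone weights stay fixed during training.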
# Define Optimizer and Loss Function
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# Step 7: Fine-Tune the Model
model.train()
for epoch in range(3):  # Number of epochs
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        images = images.to(torch.float32)  # DataLoader already yields a batched tensor
        outputs = model.get_image_features(pixel_values=images)
        logits = model.classifier(outputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")
# Save the Fine-Tuned Model
model.save_pretrained("fine-tuned-model")
processor.save_pretrained("fine-tuned-model")
print("Model fine-tuned and saved successfully.")
# Step 8: Define Gradio Inference Function
def classify_image(image, class_names):
    # Load Fine-Tuned Model
    model = CLIPModel.from_pretrained("fine-tuned-model")
    processor = CLIPProcessor.from_pretrained("fine-tuned-model")

    # Split class names from comma-separated input
    labels = [label.strip() for label in class_names.split(",") if label.strip()]
    if not labels:
        return {"Error": "Please enter at least one valid class name."}

    # Process the image and labels, then score the image against each label (zero-shot)
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities

    # Pair each label with its probability, highest first
    result = {label: probs[0][i].item() for i, label in enumerate(labels)}
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
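# Illustrative local test without the UI (assumes an image file named "example.jpg"
# exists next to app.py — a hypothetical path):
# from PIL import Image
# print(classify_image(Image.open("example.jpg"), "safe, unsafe"))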
# Step 9: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Possible class names (comma-separated)", placeholder="e.g., safe, unsafe")
    ],
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)
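# gr.Label renders a {label: probability} dict like the one classify_image returns;
# num_top_classes=2 limits the display to the two highest-scoring labels.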
# Launch Gradio Interface
if __name__ == "__main__":
    iface.launch()