import os
import zipfile

# Ensure PyTorch is installed before importing anything that depends on it
try:
    import torch
except ModuleNotFoundError:
    print("PyTorch is not installed. Installing now...")
    os.system("pip install torch torchvision torchaudio")
    import torch

from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import gradio as gr
# Step 1: Unzip the dataset
if not os.path.exists("data"):
    os.makedirs("data")
    print("Extracting Data.zip...")
    with zipfile.ZipFile("Data.zip", 'r') as zip_ref:
        zip_ref.extractall("data")
    print("Extraction complete.")
# Step 2: Dynamically find the 'safe' and 'unsafe' folders
def find_dataset_path(root_dir):
    for root, dirs, files in os.walk(root_dir):
        if 'safe' in dirs and 'unsafe' in dirs:
            return root
    return None
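# Expected layout after extraction (assumed from the search below):
#   data/Data/safe/*.jpg|*.jpeg|*.png
#   data/Data/unsafe/*.jpg|*.jpeg|*.png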
# Look for 'safe' and 'unsafe' inside 'data/Data'
dataset_path = find_dataset_path("data/Data")
if dataset_path is None:
    print("Debugging extracted structure:")
    for root, dirs, files in os.walk("data"):
        print(f"Root: {root}")
        print(f"Directories: {dirs}")
        print(f"Files: {files}")
    raise FileNotFoundError("Expected 'safe' and 'unsafe' folders not found inside 'data/Data'. Please check the Data.zip structure.")
print(f"Dataset path found: {dataset_path}")
# Step 3: Define Custom Dataset Class
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for label, folder in enumerate(["safe", "unsafe"]):  # 0 = safe, 1 = unsafe
            folder_path = os.path.join(root_dir, folder)
            if not os.path.exists(folder_path):
                raise FileNotFoundError(f"Folder '{folder}' not found in '{root_dir}'")
            for filename in os.listdir(folder_path):
                if filename.endswith((".jpg", ".jpeg", ".png")):  # Only load image files
                    self.image_paths.append(os.path.join(folder_path, filename))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
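# Usage note: each sample from CustomImageDataset is a (transformed image, label) pair;
# with the transform below, that is a (3, 224, 224) float tensor and an int (0 = safe, 1 = unsafe).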
# Step 4: Data Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),         # Resize to 224x224 pixels
    transforms.ToTensor(),                 # Convert to tensor
    transforms.Normalize((0.5,), (0.5,)),  # Normalize image values
])
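# Note: CLIP's own processor normalizes images with model-specific per-channel mean/std values;
# the simpler 0.5/0.5 normalization used here is an approximation of that preprocessing.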
# Step 5: Load the Dataset
train_dataset = CustomImageDataset(dataset_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Debugging: Check the dataset
print(f"Number of samples in the dataset: {len(train_dataset)}")
if len(train_dataset) == 0:
    raise ValueError("The dataset is empty. Please check if 'Data.zip' is correctly unzipped and contains 'safe' and 'unsafe' folders.")
# Step 6: Load Pretrained CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Add a Classification Layer
model.classifier = nn.Linear(model.visual_projection.out_features, 2)  # 2 classes: safe, unsafe

# Define Optimizer and Loss Function
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
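# Only the new classifier head's parameters are passed to the optimizer,
# so the pretrained CLIP backbone weights are left unchanged during training.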
# Step 7: Fine-Tune the Model
model.train()
for epoch in range(3):  # Number of epochs
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        images = images.to(torch.float32)  # Batch of images from the DataLoader
        outputs = model.get_image_features(pixel_values=images)
        logits = model.classifier(outputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")
# Save the Fine-Tuned Model
model.save_pretrained("fine-tuned-model")
processor.save_pretrained("fine-tuned-model")
print("Model fine-tuned and saved successfully.")
# Step 8: Define Gradio Inference Function
def classify_image(image, class_names):
    # Load Fine-Tuned Model
    model = CLIPModel.from_pretrained("fine-tuned-model")
    processor = CLIPProcessor.from_pretrained("fine-tuned-model")

    # Split class names from comma-separated input
    labels = [label.strip() for label in class_names.split(",") if label.strip()]
    if not labels:
        return {"Error": "Please enter at least one valid class name."}

    # Process the image and labels
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():  # Inference only; no gradients needed
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities

    # Extract labels with their corresponding probabilities
    result = {label: probs[0][i].item() for i, label in enumerate(labels)}
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
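# Design note: classify_image reloads the model and processor on every request; for a busier
# Space, loading them once at module scope would avoid repeated disk reads per prediction.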
# Step 9: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Possible class names (comma-separated)", placeholder="e.g., safe, unsafe")
    ],
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)

# Launch Gradio Interface
if __name__ == "__main__":
    iface.launch()
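# Minimal local sanity check (sketch): "example.jpg" is a hypothetical placeholder, not part of
# this repo. Uncomment to exercise classify_image without the Gradio UI:
# print(classify_image(Image.open("example.jpg"), "safe, unsafe"))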