import os
import zipfile
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import gradio as gr
# Ensure PyTorch is installed
try:
    import torch
except ModuleNotFoundError:
    print("PyTorch is not installed. Installing now...")
    os.system("pip install torch torchvision torchaudio")
    import torch
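# Note: torch is already imported at the top of this file, so this fallback only
# matters if that import is removed. On Hugging Face Spaces, pinning
# torch/torchvision in requirements.txt is more reliable than a runtime pip install.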
# Step 1: Unzip the dataset
if not os.path.exists("data"):
    os.makedirs("data")

print("Extracting Data.zip...")
with zipfile.ZipFile("Data.zip", 'r') as zip_ref:
    zip_ref.extractall("data")
print("Extraction complete.")
# Step 2: Dynamically find the 'safe' and 'unsafe' folders
def find_dataset_path(root_dir):
    for root, dirs, files in os.walk(root_dir):
        if 'safe' in dirs and 'unsafe' in dirs:
            return root
    return None
# Look for 'safe' and 'unsafe' inside 'data/Data'
dataset_path = find_dataset_path("data/Data")
if dataset_path is None:
    print("Debugging extracted structure:")
    for root, dirs, files in os.walk("data"):
        print(f"Root: {root}")
        print(f"Directories: {dirs}")
        print(f"Files: {files}")
    raise FileNotFoundError(
        "Expected 'safe' and 'unsafe' folders not found inside 'data/Data'. "
        "Please check the Data.zip structure."
    )
print(f"Dataset path found: {dataset_path}")
# Step 3: Define Custom Dataset Class
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for label, folder in enumerate(["safe", "unsafe"]):  # 0 = safe, 1 = unsafe
            folder_path = os.path.join(root_dir, folder)
            if not os.path.exists(folder_path):
                raise FileNotFoundError(f"Folder '{folder}' not found in '{root_dir}'")
            for filename in os.listdir(folder_path):
                # Only load image files (case-insensitive extension check)
                if filename.lower().endswith((".jpg", ".jpeg", ".png")):
                    self.image_paths.append(os.path.join(folder_path, filename))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
# Step 4: Data Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 pixels (ViT-B/32 input size)
    transforms.ToTensor(),  # Convert to tensor with values in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize all three RGB channels
])
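# Note: CLIP checkpoints are normally fed images preprocessed with the model's own
# normalization statistics, which the processor exposes, e.g.:
#   clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
#   mean = clip_processor.image_processor.image_mean
#   std = clip_processor.image_processor.image_std
# The generic 0.5 mean/std used above is a simplification; the classifier head added
# below is trained on features computed from these inputs, so it remains consistent.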
# Step 5: Load the Dataset
train_dataset = CustomImageDataset(dataset_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# Debugging: Check the dataset
print(f"Number of samples in the dataset: {len(train_dataset)}")
if len(train_dataset) == 0:
    raise ValueError(
        "The dataset is empty. Please check if 'Data.zip' is correctly unzipped "
        "and contains 'safe' and 'unsafe' folders."
    )
# Step 6: Load Pretrained CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Add a Classification Layer
model.classifier = nn.Linear(model.visual_projection.out_features, 2) # 2 classes: safe, unsafe
# Define Optimizer and Loss Function
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
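# Because only model.classifier's parameters are passed to the optimizer, the CLIP
# backbone is never updated by optimizer.step(). Optionally (sketch, not part of the
# original script), backbone gradients could also be disabled to save memory:
# for param in model.parameters():
#     param.requires_grad = False
# for param in model.classifier.parameters():
#     param.requires_grad = True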
# Step 7: Fine-Tune the Model
model.train()
for epoch in range(3):  # Number of epochs
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        images = images.to(torch.float32)  # DataLoader already returns a stacked batch tensor
        outputs = model.get_image_features(pixel_values=images)
        logits = model.classifier(outputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")
# Save the Fine-Tuned Model
model.save_pretrained("fine-tuned-model")
processor.save_pretrained("fine-tuned-model")
print("Model fine-tuned and saved successfully.")
# Step 8: Define Gradio Inference Function
def classify_image(image, class_names):
    """Score an image against user-supplied class names with the fine-tuned CLIP model."""
    # Load the fine-tuned model and processor
    model = CLIPModel.from_pretrained("fine-tuned-model")
    processor = CLIPProcessor.from_pretrained("fine-tuned-model")
    model.eval()

    # Split class names from the comma-separated input
    labels = [label.strip() for label in class_names.split(",") if label.strip()]
    if not labels:
        return {"Error": "Please enter at least one valid class name."}

    # Process the image and labels, then score the image against each label
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities

    # Map each label to its probability, sorted highest first
    result = {label: probs[0][i].item() for i, label in enumerate(labels)}
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
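# Hypothetical direct usage outside Gradio, assuming an image file is available:
#   result = classify_image(Image.open("example.jpg"), "safe, unsafe")
#   # result maps each supplied label to a probability, highest first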
# Step 9: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Possible class names (comma-separated)", placeholder="e.g., safe, unsafe"),
    ],
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)
# Launch Gradio Interface
if __name__ == "__main__":
    iface.launch()
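# When testing locally outside Spaces, iface.launch(share=True) can expose a
# temporary public URL; on Spaces the default launch() settings are sufficient.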