import os
import zipfile

# Make sure PyTorch is available before importing anything that depends on it
try:
    import torch
except ModuleNotFoundError:
    print("PyTorch is not installed. Installing now...")
    os.system("pip install torch torchvision torchaudio")
    import torch

from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import gradio as gr

# Step 1: Unzip the dataset
os.makedirs("data", exist_ok=True)
print("Extracting Data.zip...")
with zipfile.ZipFile("Data.zip", "r") as zip_ref:
    zip_ref.extractall("data")
print("Extraction complete.")

# Step 2: Dynamically find the 'safe' and 'unsafe' folders
def find_dataset_path(root_dir):
    for root, dirs, files in os.walk(root_dir):
        if "safe" in dirs and "unsafe" in dirs:
            return root
    return None

# Look for 'safe' and 'unsafe' inside 'data/Data'
dataset_path = find_dataset_path("data/Data")
if dataset_path is None:
    print("Debugging extracted structure:")
    for root, dirs, files in os.walk("data"):
        print(f"Root: {root}")
        print(f"Directories: {dirs}")
        print(f"Files: {files}")
    raise FileNotFoundError(
        "Expected 'safe' and 'unsafe' folders not found inside 'data/Data'. "
        "Please check the Data.zip structure."
    )
print(f"Dataset path found: {dataset_path}")

# Step 3: Define Custom Dataset Class
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for label, folder in enumerate(["safe", "unsafe"]):  # 0 = safe, 1 = unsafe
            folder_path = os.path.join(root_dir, folder)
            if not os.path.exists(folder_path):
                raise FileNotFoundError(f"Folder '{folder}' not found in '{root_dir}'")
            for filename in os.listdir(folder_path):
                if filename.lower().endswith((".jpg", ".jpeg", ".png")):  # Only load image files
                    self.image_paths.append(os.path.join(folder_path, filename))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Step 4: Data Transformations
# Note: CLIP's own processor normalizes with different per-channel statistics;
# the (0.5, 0.5, 0.5) values are kept from the original pipeline, but matching
# the processor's statistics would feed the backbone inputs closer to what it
# saw during pretraining.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the 224x224 input size CLIP expects
    transforms.ToTensor(),          # Convert to a [0, 1] float tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize all three RGB channels
])

# Step 5: Load the Dataset
train_dataset = CustomImageDataset(dataset_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Debugging: Check the dataset
print(f"Number of samples in the dataset: {len(train_dataset)}")
if len(train_dataset) == 0:
    raise ValueError(
        "The dataset is empty. Please check if 'Data.zip' is correctly unzipped "
        "and contains 'safe' and 'unsafe' folders."
    )
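# Optional sketch (not part of the original pipeline): the script trains on
# every image, so there is no held-out data to measure overfitting of the
# classifier head. A minimal 80/20 split with torch.utils.data.random_split
# is shown below; the ratio and the `val_loader` name are illustrative
# assumptions. In practice, train_loader would be rebuilt from train_subset.
from torch.utils.data import random_split

val_size = max(1, int(0.2 * len(train_dataset)))
train_subset, val_subset = random_split(
    train_dataset, [len(train_dataset) - val_size, val_size]
)
val_loader = DataLoader(val_subset, batch_size=8, shuffle=False)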
# Step 6: Load Pretrained CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Add a Classification Layer on top of the projected image embeddings
model.classifier = nn.Linear(model.visual_projection.out_features, 2)  # 2 classes: safe, unsafe

# Define Optimizer and Loss Function. Only the classifier head's parameters
# are passed to the optimizer, so the CLIP backbone itself is never updated
# (this is linear probing rather than full fine-tuning).
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Step 7: Fine-Tune the Model
model.train()
for epoch in range(3):  # Number of epochs
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        images = images.to(torch.float32)  # The DataLoader already returns a batched tensor
        outputs = model.get_image_features(pixel_values=images)
        logits = model.classifier(outputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}")

# Save the Fine-Tuned Model. Note that save_pretrained() stores the CLIP
# weights; the added classifier head is not part of the CLIP config, so
# from_pretrained() will drop it when the model is reloaded.
model.save_pretrained("fine-tuned-model")
processor.save_pretrained("fine-tuned-model")
print("Model fine-tuned and saved successfully.")

# Step 8: Define Gradio Inference Function
def classify_image(image, class_names):
    # Load the saved model. This path performs zero-shot text-image matching,
    # so the classifier head trained in Step 7 is not consulted here.
    model = CLIPModel.from_pretrained("fine-tuned-model")
    processor = CLIPProcessor.from_pretrained("fine-tuned-model")
    model.eval()

    # Split class names from comma-separated input
    labels = [label.strip() for label in class_names.split(",") if label.strip()]
    if not labels:
        return {"Error": "Please enter at least one valid class name."}

    # Process the image and labels, then score them without tracking gradients
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities

    # Map each label to its probability, sorted from most to least likely
    result = {label: probs[0][i].item() for i, label in enumerate(labels)}
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))

# Step 9: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(
            label="Possible class names (comma-separated)",
            placeholder="e.g., safe, unsafe",
        ),
    ],
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)

# Launch Gradio Interface
if __name__ == "__main__":
    iface.launch()
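# Optional sketch: because classify_image() above does zero-shot matching,
# the linear head trained in Step 7 never influences the Gradio predictions.
# A minimal example of scoring one image through that head instead is shown
# below; it reuses the in-memory `model` and `transform` from this script,
# and the helper name is hypothetical. (iface.launch() blocks, so run this
# before launching the app or in a separate session.)
def classify_with_trained_head(image_path):
    """Score a single image with the Step 7 linear head (0 = safe, 1 = unsafe)."""
    image = Image.open(image_path).convert("RGB")
    pixel_values = transform(image).unsqueeze(0)  # Add a batch dimension
    model.eval()
    with torch.no_grad():
        features = model.get_image_features(pixel_values=pixel_values)
        probs = model.classifier(features).softmax(dim=1)[0]
    return {"safe": probs[0].item(), "unsafe": probs[1].item()}

# Example usage (the path is illustrative):
# print(classify_with_trained_head("example.jpg"))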