import os
import zipfile

# Ensure PyTorch is installed before importing torch-dependent packages
try:
    import torch
except ModuleNotFoundError:
    print("PyTorch is not installed. Installing now...")
    os.system("pip install torch torchvision torchaudio")
    import torch

from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import gradio as gr

# Step 1: Unzip the dataset
os.makedirs("data", exist_ok=True)

print("Extracting Data.zip...")
with zipfile.ZipFile("Data.zip", 'r') as zip_ref:
    zip_ref.extractall("data")
print("Extraction complete.")

# Step 2: Dynamically find the 'safe' and 'unsafe' folders
def find_dataset_path(root_dir):
    for root, dirs, files in os.walk(root_dir):
        if 'safe' in dirs and 'unsafe' in dirs:
            return root
    return None

# Look for 'safe' and 'unsafe' inside 'data/Data'
dataset_path = find_dataset_path("data/Data")
if dataset_path is None:
    print("Debugging extracted structure:")
    for root, dirs, files in os.walk("data"):
        print(f"Root: {root}")
        print(f"Directories: {dirs}")
        print(f"Files: {files}")
    raise FileNotFoundError("Expected 'safe' and 'unsafe' folders not found inside 'data/Data'. Please check the Data.zip structure.")
print(f"Dataset path found: {dataset_path}")
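# Expected layout after extraction (based on the checks above):
#   data/Data/
#     safe/    <image files>
#     unsafe/  <image files>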

# Step 3: Define Custom Dataset Class
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for label, folder in enumerate(["safe", "unsafe"]):  # 0 = safe, 1 = unsafe
            folder_path = os.path.join(root_dir, folder)
            if not os.path.exists(folder_path):
                raise FileNotFoundError(f"Folder '{folder}' not found in '{root_dir}'")
            for filename in os.listdir(folder_path):
                if filename.lower().endswith((".jpg", ".jpeg", ".png")):  # Only load image files (case-insensitive)
                    self.image_paths.append(os.path.join(folder_path, filename))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Step 4: Data Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),                           # Resize to the 224x224 input size expected by CLIP
    transforms.ToTensor(),                                   # Convert to a [0, 1] float tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize each RGB channel
])

# Step 5: Load the Dataset
train_dataset = CustomImageDataset(dataset_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Debugging: Check the dataset
print(f"Number of samples in the dataset: {len(train_dataset)}")
if len(train_dataset) == 0:
    raise ValueError("The dataset is empty. Please check if 'Data.zip' is correctly unzipped and contains 'safe' and 'unsafe' folders.")
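
# Optional sanity check (a minimal sketch, not part of the original pipeline):
# pull one batch from train_loader to confirm tensor shapes before training.
sample_images, sample_labels = next(iter(train_loader))
print(f"Sample batch image shape: {tuple(sample_images.shape)}")  # expected (8, 3, 224, 224)
print(f"Sample batch labels: {sample_labels.tolist()}")           # 0 = safe, 1 = unsafe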

# Step 6: Load Pretrained CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Add a Classification Layer
model.classifier = nn.Linear(model.visual_projection.out_features, 2)  # 2 classes: safe, unsafe

# Define Optimizer and Loss Function (only the new classifier head is updated;
# the pretrained CLIP backbone keeps its weights)
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Step 7: Fine-Tune the Model
model.train()
for epoch in range(3):  # Number of epochs
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        images = images.to(torch.float32)  # DataLoader already yields a stacked batch tensor
        outputs = model.get_image_features(pixel_values=images)
        logits = model.classifier(outputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

# Save the Fine-Tuned Model (note: the added classifier head is not restored by
# CLIPModel.from_pretrained; the Gradio inference below uses CLIP's zero-shot
# text-image matching instead)
model.save_pretrained("fine-tuned-model")
processor.save_pretrained("fine-tuned-model")
print("Model fine-tuned and saved successfully.")
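
# Optional evaluation (a minimal sketch, an addition beyond the original script):
# report the classifier head's accuracy on the training images as a quick sanity check.
model.eval()
correct = 0
with torch.no_grad():
    for images, labels in train_loader:
        features = model.get_image_features(pixel_values=images.to(torch.float32))
        preds = model.classifier(features).argmax(dim=1)
        correct += (preds == labels).sum().item()
print(f"Classifier-head accuracy on the training set: {correct / len(train_dataset):.2%}")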

# Step 8: Define Gradio Inference Function
def classify_image(image, class_names):
    # Load the fine-tuned model (reloaded on every call for simplicity; caching it
    # outside the function would make repeated predictions faster)
    model = CLIPModel.from_pretrained("fine-tuned-model")
    processor = CLIPProcessor.from_pretrained("fine-tuned-model")
    
    # Split class names from comma-separated input
    labels = [label.strip() for label in class_names.split(",") if label.strip()]
    if not labels:
        return "Please enter at least one valid class name."  # gr.Label accepts a plain string

    # Process the image and labels, then score each label against the image
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():  # Inference only; no gradients needed
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities

    # Extract labels with their corresponding probabilities
    result = {label: probs[0][i].item() for i, label in enumerate(labels)}
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
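
# Quick local check without the Gradio UI (hypothetical image path, adjust to your data):
#   result = classify_image(Image.open("data/Data/safe/example.jpg"), "safe, unsafe")
#   print(result)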

# Step 9: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Possible class names (comma-separated)", placeholder="e.g., safe, unsafe")
    ],
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)

# Launch Gradio Interface
if __name__ == "__main__":
    iface.launch()