|
---
license: mit
pipeline_tag: zero-shot-image-classification
---
|
# **CLIP Vision Model With Projection** |
|
This repository provides a custom vision-only (image-encoder) CLIP model that wraps a CLIPVisionModel with a trainable projection layer, a self-attention-based post-processing block, and a gating mechanism for multi-image inputs.
|
The student model is distilled from two teacher models (CLIP and FLAVA). During training, it receives two clusters of images (e.g., 10 in one cluster and 30 in the other). |
|
Instead of enforcing an equal selection from each cluster, the student applies a learnable gating mechanism that assigns a soft probability score to each image, effectively deciding how many images to select from each cluster. |
|
These selected images (weighted by their gating probabilities) are then combined into a single embedding, which the model uses to compare cluster similarity. |
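
How the released model implements this gating lives in its remote code; the snippet below is only a minimal sketch of the idea, assuming a single linear scoring layer and a softmax-weighted sum within one collection (all names and sizes here are illustrative, not the model's actual API).

```python
import torch
import torch.nn as nn

# Illustrative only: score each image embedding, then combine one
# collection's images into a single embedding via a softmax-weighted sum.
projection_dim = 512                      # assumed embedding size
gate = nn.Linear(projection_dim, 1)       # hypothetical gating layer

def pool_collection(image_embeds: torch.Tensor) -> torch.Tensor:
    """image_embeds: [num_images, projection_dim] for one collection."""
    scores = gate(image_embeds)                 # [num_images, 1]
    weights = torch.softmax(scores, dim=0)      # soft "selection" probabilities
    return (weights * image_embeds).sum(dim=0)  # single [projection_dim] embedding
```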
|
A reinforcement learning reward encourages the student to produce embeddings that bring images from the same cluster closer together while separating images from different clusters. |
|
The strategy is to let the model sample images from the two clusters in a way that maximises the reward, which is tied to the objective of making the clusters similar or dissimilar (a sketch of such a reward follows the list below).
|
This facilitates unsupervised learning in scenarios where: |
|
1. Contrastive loss design is tricky because the same place appears in many different views.
|
2. Data and labels are noisy. |
|
3. A group of images, rather than a single image, gives the complete picture.
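
The exact reward used in training is not spelled out here; the following is a minimal sketch of one plausible cluster-level reward, assuming cosine similarity between L2-normalised per-image embeddings, where within-cluster agreement is rewarded and between-cluster similarity is penalised.

```python
import torch
import torch.nn.functional as F

def cluster_reward(embeds_a: torch.Tensor, embeds_b: torch.Tensor) -> torch.Tensor:
    """Hypothetical reward. embeds_a: [Na, D], embeds_b: [Nb, D] per-image embeddings.
    High when images within each cluster agree and the two clusters differ."""
    a = F.normalize(embeds_a, dim=-1)
    b = F.normalize(embeds_b, dim=-1)
    intra = (a @ a.T).mean() + (b @ b.T).mean()  # within-cluster similarity
    inter = (a @ b.T).mean()                     # between-cluster similarity
    return intra - inter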
|
|
|
## **Overview** |
|
### **Teacher Model** |
|
A frozen CLIPVisionModel and a frozen FLAVA model (both from transformers); the base vision parameters do not update.
|
### **Student Model** |
|
A 310M-parameter model that learns from both teacher models. The student decides which teacher to focus on more based on the reward, given each teacher's policy. (An initial pretraining step makes the two teachers' embeddings similar without loss in their individual classification performance.)
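
How the student balances the two teachers is defined in the training code; purely as an illustration, one way to express a learnable "focus" over the teachers is a softmax over two logits, as in the hypothetical sketch below.

```python
import torch
import torch.nn as nn

# Hypothetical illustration only: learnable logits decide how much the
# student's distillation target leans on each teacher (CLIP vs. FLAVA).
teacher_logits = nn.Parameter(torch.zeros(2))

def distillation_target(clip_embed: torch.Tensor, flava_embed: torch.Tensor) -> torch.Tensor:
    w = torch.softmax(teacher_logits, dim=0)      # soft focus over the two teachers
    return w[0] * clip_embed + w[1] * flava_embed
```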
|
### **Trainable Projection** |
|
A linear layer mapping the hidden size (config.hidden_size) to a new dimension (projection_dim). |
|
### **PostProjection** |
|
A multi-head self-attention block that refines the projected embeddings further. |
|
### **Gating Layer** |
|
When multiple images per sample are provided, a small gating network learns how to combine them into a single embedding via a weighted sum. |
|
### **Final Non-Linear Projection** |
|
A feed-forward network on top of the post-projection embeddings. |
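
For orientation, the sketch below stacks the trainable pieces described above (projection, self-attention post-projection, final feed-forward head) in a plausible order; module names, sizes, and the pooling step are assumptions, not the repository's actual implementation.

```python
import torch
import torch.nn as nn

# Illustrative component stack; names and dimensions are assumptions.
hidden_size, projection_dim, num_heads = 1024, 512, 8

visual_projection = nn.Linear(hidden_size, projection_dim)                    # trainable projection
post_projection = nn.MultiheadAttention(projection_dim, num_heads, batch_first=True)
final_projection = nn.Sequential(                                             # final non-linear projection
    nn.Linear(projection_dim, projection_dim),
    nn.GELU(),
    nn.Linear(projection_dim, projection_dim),
)

def encode(vision_hidden_states: torch.Tensor) -> torch.Tensor:
    """vision_hidden_states: [batch, seq_len, hidden_size] from the vision backbone."""
    x = visual_projection(vision_hidden_states)  # [batch, seq_len, projection_dim]
    x, _ = post_projection(x, x, x)              # self-attention refinement
    x = x.mean(dim=1)                            # pool patch tokens (assumed)
    return final_projection(x)                   # [batch, projection_dim]
```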
|
|
|
## **Installation & Usage** |
|
```bash
pip3 install torch transformers Pillow pandas numpy
```
|
## **Model Loading** |
|
```python
from transformers import CLIPProcessor, AutoConfig, AutoModel
import torch

# Detect device
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device (Apple GPU)")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
else:
    device = torch.device("cpu")
    print("Using CPU")

# 1) Load the model and processor
config = AutoConfig.from_pretrained("paytm/StoreClip")
processor = CLIPProcessor.from_pretrained("paytm/StoreClip")
model = AutoModel.from_pretrained("paytm/StoreClip", config=config, trust_remote_code=True).to(device)
```
|
|
|
## **Example Usage** |
|
```python
# Suppose we have a Python list of image_collections, each with its own list of images
# Example:
# image_groups = [
#     [PIL_imgA1, PIL_imgA2],   # image_collection A
#     [PIL_imgB1],              # image_collection B
#     [PIL_imgC1, PIL_imgC2],   # image_collection C
#     ...
# ]

from pathlib import Path
import os
from PIL import Image

# 1) Load image lists from directories

# Function to load images from a directory
def load_images_from_directory(directory):
    image_files = []
    images = []
    for file in os.listdir(directory):
        if file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(directory, file)
            try:
                image = Image.open(image_path).convert('RGB')
                images.append(image)
                image_files.append(file)
            except Exception as e:
                print(f"Error loading {file}: {e}")
    return images, image_files
```
|
|
|
```python
# path_to_dir: top-level directory whose sub-directories each hold one image_collection
data_dir = Path(path_to_dir)
directories = [d for d in data_dir.iterdir() if d.is_dir()]
image_groups = []
image_paths = []
for directory in directories:
    images, files = load_images_from_directory(directory)
    image_groups.append(images)
    image_paths.append(files)

# 2) Compute how many images each image_collection has
counts = [len(image_list) for image_list in image_groups]

# 3) Flatten all images across all image_collections
all_images = []
for image_list in image_groups:
    all_images.extend(image_list)
inputs = processor(images=all_images, return_tensors="pt", padding=True)

# 4) Now pass pixel_values + counts directly to the model
pixel_values = inputs["pixel_values"]
pixel_values = pixel_values.to(device)  # move to the selected device

model.eval()
with torch.no_grad():
    outputs = model(
        pixel_values=pixel_values,
        counts=counts,
        normalize=True
    )

# outputs.image_embeds => shape [len(image_groups), projection_dim]
# i.e. one embedding per image_collection
collection_embeds = outputs.image_embeds

# outputs.teacher_embeds => shape [len(all_images), projection_dim]
# i.e. one embedding per image
individual_embeds = outputs.teacher_embeds
```
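
The collection-level embeddings can then be compared directly. Assuming `normalize=True` returns unit-length embeddings, a matrix product gives pairwise cosine similarities between image_collections:

```python
# Pairwise cosine similarity between image_collections
# (assumes normalize=True produced unit-length embeddings).
similarity = collection_embeds @ collection_embeds.T  # [num_collections, num_collections]
print(similarity.cpu().numpy())
```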