CLIP Vision Model With Projection
This repository provides a custom vision-only (image-encoder) CLIP model that wraps a CLIPVisionModel with a trainable projection layer, a self-attention-based post-processing block, and a gating mechanism for multi-image inputs. The student model is distilled from two teacher models (CLIP and FLAVA). During training, it receives two clusters of images (e.g., 10 images in one cluster and 30 in the other). Instead of enforcing an equal selection from each cluster, the student applies a learnable gating mechanism that assigns a soft probability score to each image, effectively deciding how many images to select from each cluster. The selected images, weighted by their gating probabilities, are then combined into a single embedding per cluster, which the model uses to compare the clusters' similarity. A reinforcement learning reward encourages the student to produce embeddings that bring images from the same cluster closer together while separating images from different clusters. In other words, the model learns to sample images from the two clusters in whatever way maximises the reward, which is tied to the objective of making the clusters similar or dissimilar. This enables unsupervised learning in scenarios where:
- Designing a contrastive loss is tricky because the data contains different views of the same place.
- Data and labels are noisy.
- A group of images, rather than any single image, gives the complete picture.
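The README does not publish the training loop. The sketch below shows one way the gated sampling and reward described above could be wired together, and it rests on several assumptions: each image gets a keep-probability from a sigmoid gate, images are sampled with a Bernoulli draw, the selected images are pooled with their gate probabilities as weights, and the reward is the cosine similarity of the two pooled cluster embeddings (negated when the clusters are meant to differ). `GatedSampler`, `reinforce_loss` and `same_place` are hypothetical names, and random tensors stand in for the encoder's image embeddings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedSampler(nn.Module):
    """Hypothetical gate: gives each image embedding a keep-probability."""

    def __init__(self, dim: int):
        super().__init__()
        self.score = nn.Linear(dim, 1)

    def forward(self, embeds: torch.Tensor):
        # embeds: [num_images, dim] -> probs: [num_images]
        probs = torch.sigmoid(self.score(embeds)).squeeze(-1)
        dist = torch.distributions.Bernoulli(probs=probs)
        mask = dist.sample()                   # 1.0 = image selected, 0.0 = skipped
        log_prob = dist.log_prob(mask).sum()   # log-probability of this particular selection
        weights = probs * mask                 # selected images weighted by their gate probability
        # Weighted mean of the selected images; clamp avoids dividing by zero if none is selected
        pooled = (weights.unsqueeze(-1) * embeds).sum(0) / weights.sum().clamp_min(1e-6)
        return pooled, log_prob


def reinforce_loss(sampler, cluster_a, cluster_b, same_place: bool):
    pooled_a, logp_a = sampler(cluster_a)
    pooled_b, logp_b = sampler(cluster_b)
    sim = F.cosine_similarity(pooled_a, pooled_b, dim=0)
    # Assumed reward: similarity for clusters of the same place, negated similarity otherwise
    reward = sim if same_place else -sim
    # REINFORCE: raise the log-probability of selections that earned a high reward
    loss = -(reward.detach() * (logp_a + logp_b))
    return loss, reward.item()


torch.manual_seed(0)
sampler = GatedSampler(dim=16)
cluster_a = torch.randn(10, 16)  # stand-in embeddings: 10 images in one cluster
cluster_b = torch.randn(30, 16)  # and 30 in the other
loss, reward = reinforce_loss(sampler, cluster_a, cluster_b, same_place=True)
loss.backward()
```

Detaching the reward means the gate is trained only through the log-probability term, which is the standard score-function (REINFORCE) estimator; in practice a baseline is usually subtracted from the reward to reduce gradient variance.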
Overview
Teacher Model
Two frozen teachers, a CLIPVisionModel and a FLAVA model (both from transformers). Frozen means their base vision parameters are not updated during training.
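In PyTorch, freezing a model usually means turning off `requires_grad` on every parameter and switching the module to eval mode so that dropout is disabled. A minimal sketch; `freeze` is a hypothetical helper, and small `nn.Linear` stand-ins replace the real teachers so the snippet runs without downloading weights.

```python
import torch.nn as nn


def freeze(module: nn.Module) -> nn.Module:
    """Stop gradient updates and switch to inference behaviour (dropout off)."""
    for param in module.parameters():
        param.requires_grad_(False)
    return module.eval()  # eval() returns the module itself


# Stand-ins for the teachers; with real weights these would be a CLIPVisionModel
# and a FLAVA model loaded via from_pretrained(...).
clip_teacher = freeze(nn.Linear(768, 512))
flava_teacher = freeze(nn.Linear(768, 512))
```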
Student Model
A 310M-parameter model that learns from both teacher models. Guided by the reward, the student decides which teacher to focus on more, given each teacher's policy. (An initial pretraining stage aligns the two teachers' embeddings without any loss in either teacher's individual classification performance.)
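The README does not spell out how the two teachers are weighted. One plausible reading, sketched here purely as an assumption, is a distillation loss in which the student's cosine distance to each teacher is weighted by a softmax over the reward obtained with that teacher. `multi_teacher_distill_loss`, `teacher_rewards` and `temperature` are hypothetical names, and the teacher embeddings are assumed to already share the student's dimension.

```python
import torch
import torch.nn.functional as F


def multi_teacher_distill_loss(student, teachers, teacher_rewards, temperature=1.0):
    """
    student:         [batch, dim] student embeddings
    teachers:        list of [batch, dim] teacher embeddings (assumed in a shared dim)
    teacher_rewards: [num_teachers] reward obtained with each teacher (assumed given)
    """
    # Teachers that earned a higher reward get a larger weight; weights sum to 1
    weights = torch.softmax(teacher_rewards / temperature, dim=0)
    # Cosine distance to each (frozen, hence detached) teacher, averaged over the batch
    per_teacher = torch.stack([
        (1 - F.cosine_similarity(student, t.detach(), dim=-1)).mean() for t in teachers
    ])
    return (weights * per_teacher).sum(), weights


torch.manual_seed(0)
student = torch.randn(4, 512, requires_grad=True)
clip_embeds, flava_embeds = torch.randn(4, 512), torch.randn(4, 512)
loss, weights = multi_teacher_distill_loss(
    student, [clip_embeds, flava_embeds], teacher_rewards=torch.tensor([0.8, 0.2])
)
loss.backward()
```

Lowering `temperature` makes the weighting sharper, so the student leans almost entirely on the higher-reward teacher.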
Trainable Projection
A linear layer mapping the hidden size (config.hidden_size) to a new dimension (projection_dim).
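As a concrete illustration, such a projection is a single `nn.Linear`. The sizes below (`hidden_size = 1024`, `projection_dim = 512`) are assumed values for the sketch; the real ones come from the model config.

```python
import torch
import torch.nn as nn

hidden_size, projection_dim = 1024, 512  # assumed values; read them from the config in practice
projection = nn.Linear(hidden_size, projection_dim)

pooled = torch.randn(8, hidden_size)     # e.g. pooled vision outputs for 8 images
projected = projection(pooled)
print(projected.shape)                   # torch.Size([8, 512])
```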
PostProjection
A multi-head self-attention block that refines the projected embeddings further.
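A minimal sketch of what a post-projection self-attention block can look like: `nn.MultiheadAttention` followed by a residual connection and `LayerNorm`. The residual/LayerNorm layout, the head count, and the choice of sequence (for instance, the images of one collection) are assumptions; `PostProjectionSketch` is a hypothetical name, not the repository's class.

```python
import torch
import torch.nn as nn


class PostProjectionSketch(nn.Module):
    """Hypothetical block: self-attention + residual connection + LayerNorm."""

    def __init__(self, dim: int = 512, num_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, dim]; every position attends to every other position
        attended, _ = self.attn(x, x, x)
        return self.norm(x + attended)  # the residual keeps the original projected signal


block = PostProjectionSketch()
refined = block(torch.randn(2, 5, 512))  # 2 sequences of 5 projected embeddings
```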
Gating Layer
When multiple images per sample are provided, a small gating network learns how to combine them into a single embedding via a weighted sum.
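The sketch below shows one way such a gate can turn a variable number of images into one embedding per sample. It assumes a small MLP scores each image, the scores are softmax-normalised within each sample, and images arrive flattened with a `counts` list, mirroring how the Example Usage passes `counts` to the model. `GatedPoolingSketch` is a hypothetical name.

```python
from typing import List

import torch
import torch.nn as nn


class GatedPoolingSketch(nn.Module):
    """Hypothetical gate: one weight per image, softmax-normalised within each group."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Sequential(nn.Linear(dim, dim // 2), nn.Tanh(), nn.Linear(dim // 2, 1))

    def forward(self, embeds: torch.Tensor, counts: List[int]) -> torch.Tensor:
        # embeds: [sum(counts), dim], the images of all groups concatenated in order
        pooled = []
        for group in torch.split(embeds, counts):
            weights = torch.softmax(self.gate(group), dim=0)  # [n_i, 1], sums to 1 per group
            pooled.append((weights * group).sum(dim=0))       # weighted sum -> [dim]
        return torch.stack(pooled)                             # [len(counts), dim]


torch.manual_seed(0)
pool = GatedPoolingSketch(dim=512)
embeds = torch.randn(5, 512)                   # 5 images in total
group_embeds = pool(embeds, counts=[2, 1, 2])  # -> 3 group embeddings
```

A group with a single image gets weight 1.0, so its pooled embedding is just that image's embedding.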
Final Non-Linear Projection
A feed-forward network on top of the post-projection embeddings.
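A feed-forward head of this kind is typically two linear layers with a non-linearity in between. In this sketch the 2× expansion, the GELU activation and `projection_dim = 512` are assumptions, not values taken from the model.

```python
import torch
import torch.nn as nn

projection_dim = 512  # assumed; must match the post-projection output size
final_projection = nn.Sequential(
    nn.Linear(projection_dim, projection_dim * 2),  # expand
    nn.GELU(),                                      # non-linearity
    nn.Linear(projection_dim * 2, projection_dim),  # project back to the embedding size
)
out = final_projection(torch.randn(3, projection_dim))
```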
Installation & Usage
pip3 install torch transformers Pillow pandas numpy
Model Loading
from transformers import CLIPProcessor, AutoConfig, AutoModel
import torch
# Detect device
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device (Apple GPU)")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
else:
    device = torch.device("cpu")
    print("Using CPU")
# 1) Load the model and processor (custom model code, hence trust_remote_code=True)
config = AutoConfig.from_pretrained("paytm/StoreClip", trust_remote_code=True)
processor = CLIPProcessor.from_pretrained("paytm/StoreClip")
model = AutoModel.from_pretrained("paytm/StoreClip", config=config, trust_remote_code=True).to(device)
Example Usage
# Suppose we have a Python list of image_collections, each with its own list-of-images
# Example:
# image_groups = [
# [PIL_imgA1, PIL_imgA2], # image_collection A
# [PIL_imgB1], # image_collection B
# [PIL_imgC1, PIL_imgC2], # image_collection C
# ...
# ]
from pathlib import Path
import os
from PIL import Image
# 1) Load image lists from dirs
# Function to load images from directory
def load_images_from_directory(directory):
    """Return the RGB images in `directory` and their file names."""
    image_files = []
    images = []
    for file in sorted(os.listdir(directory)):  # sorted for a deterministic order
        if file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(directory, file)
            try:
                image = Image.open(image_path).convert('RGB')
                images.append(image)
                image_files.append(file)
            except Exception as e:
                print(f"Error loading {file}: {e}")
    return images, image_files
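data_dir = Path(path_to_dir)  # path_to_dir: root folder with one sub-directory per image_collection
directories = sorted(d for d in data_dir.iterdir() if d.is_dir())
image_groups = []
image_paths = []
for subdir in directories:  # avoid shadowing the built-in dir()
    images, image_files = load_images_from_directory(subdir)
    if not images:  # skip folders with no readable images so every count is >= 1
        continue
    image_groups.append(images)
    image_paths.append(image_files)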
# 2) Compute how many images each image_collection has
counts = [len(image_list) for image_list in image_groups]
# 3) Flatten all images across all image_collections
all_images = []
for image_list in image_groups:
    all_images.extend(image_list)
inputs = processor(images=all_images, return_tensors="pt", padding=True)
# 4) Now pass pixel_values + counts directly to the model
pixel_values = inputs["pixel_values"]
pixel_values = pixel_values.to(device) # if using GPU
model.eval()
with torch.no_grad():
    outputs = model(
        pixel_values=pixel_values,
        counts=counts,
        normalize=True
    )
# outputs.image_embeds => shape [len(image_groups), projection_dim]
# i.e. one embedding per image_collection
collection_embeds = outputs.image_embeds
# outputs.teacher_embeds => shape [len(all_images), projection_dim]
# i.e. one embedding per image
individual_embeds = outputs.teacher_embeds
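Since `outputs.image_embeds` holds one embedding per image_collection and the model is trained to compare clusters by similarity, a natural follow-up is to score collections against each other with cosine similarity. A minimal sketch; `collection_similarity` is a hypothetical helper, and a random tensor stands in for `collection_embeds` so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F


def collection_similarity(embeds: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between collection embeddings: [N, D] -> [N, N]."""
    unit = F.normalize(embeds, dim=-1)  # harmless if the embeddings are already unit length
    return unit @ unit.T


# With the model above this would be: sims = collection_similarity(collection_embeds.cpu())
sims = collection_similarity(torch.randn(4, 512))
# For each collection, the index of the most similar *other* collection
best_match = sims.clone().fill_diagonal_(float("-inf")).argmax(dim=1)
```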