---
license: mit
language:
  - en
pipeline_tag: zero-shot-image-classification
tags:
  - vision
  - simple
  - small
---

tinyvvision 🧠✨

tinyvvision is a compact vision-language model trained on a synthetic curriculum, designed to demonstrate genuine zero-shot matching in a minimal setup. Despite its small size (~630k parameters), it aligns images and captions by learning a shared visual-language embedding space.

What tinyvvision can do:

  • Match images of simple geometric shapes (circles, stars, hearts, triangles, etc.) to descriptive captions (e.g., "a red circle", "a yellow star").
  • Perform genuine zero-shot generalization, meaning it can correctly match captions to shapes and colors it has never explicitly encountered during training.
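Zero-shot matching here means ranking candidate captions by how close their embeddings are to an image embedding; there is no classifier head trained on a fixed label set. The sketch below isolates that ranking step with NumPy. The helper name `rank_captions` and all vectors are hypothetical: real tinyvvision embeddings are 128-dimensional outputs of its encoders, while these are made-up 4-dimensional toys.

```python
import numpy as np

def rank_captions(image_emb, caption_embs, captions):
    """Hypothetical helper: return (caption, cosine similarity) pairs, best match first.

    image_emb has shape (d,); caption_embs has shape (n, d).
    """
    img = image_emb / np.linalg.norm(image_emb)
    caps = caption_embs / np.linalg.norm(caption_embs, axis=1, keepdims=True)
    sims = caps @ img  # dot product of unit vectors = cosine similarity
    order = np.argsort(-sims)  # indices sorted by descending similarity
    return [(captions[i], float(sims[i])) for i in order]

# Toy embeddings, made up for illustration (not model outputs).
captions = ["a red circle", "a yellow star", "a teal lightning bolt"]
caption_embs = np.array([
    [0.9, 0.1, 0.0, 0.0],
    [0.0, 1.0, 0.2, 0.0],
    [0.1, 0.0, 0.1, 0.95],
])
image_emb = np.array([0.2, 0.0, 0.1, 0.9])

for caption, score in rank_captions(image_emb, caption_embs, captions):
    print(f"{caption}: {score:.3f}")
```

Normalizing both sides first makes the dot product equal to cosine similarity, so embedding magnitude does not affect the ranking. The inference script further down performs the same comparison with `torch.nn.functional.cosine_similarity`.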

Model Details:

  • Type: Contrastive embedding (CLIP-style, zero-shot)
  • Parameters: ~630,000 (tiny!)
  • Training data: Fully synthetic—randomly generated shapes, letters, numbers, and symbols paired with descriptive text captions.
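  • Architecture:
    • Image Encoder: simple CNN (three convolutional layers, 4×4 adaptive average pooling, linear projection to 128 dims); the demo script feeds it 64×64 RGB images
    • Text Encoder: 64-dimensional token embedding + 2-layer bidirectional GRU (hidden size 128) + linear projection to 128 dims; captions are truncated or padded to 16 tokens
  • Embedding Dim: 128-dimensional shared embedding space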
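The card describes the training objective only as contrastive and "CLIP-style", and the training code is not included. As an assumption, the sketch below shows the standard CLIP formulation: for a batch of N matched image/caption pairs, build an N×N matrix of temperature-scaled cosine similarities and apply cross-entropy in both directions, so each image is pulled toward its own caption and pushed away from the other N−1. The function names and the temperature of 0.07 (CLIP's initial value) are illustrative, not tinyvvision's confirmed settings.

```python
import numpy as np

def log_softmax(x, axis):
    """Numerically stable log-softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def clip_style_loss(img_embs, txt_embs, temperature=0.07):
    """Symmetric InfoNCE loss; row i of img_embs and txt_embs is matched pair i.

    temperature=0.07 is an assumed value (CLIP's initialization), not tinyvvision's.
    """
    img = img_embs / np.linalg.norm(img_embs, axis=1, keepdims=True)
    txt = txt_embs / np.linalg.norm(txt_embs, axis=1, keepdims=True)
    logits = img @ txt.T / temperature  # (N, N); the diagonal holds the matching pairs
    idx = np.arange(len(logits))
    loss_i2t = -log_softmax(logits, axis=1)[idx, idx].mean()  # each image -> its caption
    loss_t2i = -log_softmax(logits, axis=0)[idx, idx].mean()  # each caption -> its image
    return (loss_i2t + loss_t2i) / 2

rng = np.random.default_rng(0)
txt = rng.normal(size=(8, 128))
aligned = txt + 0.1 * rng.normal(size=(8, 128))  # images close to their own captions
shuffled = aligned[::-1]                          # every image paired with a wrong caption
print(clip_style_loss(aligned, txt))   # near zero
print(clip_style_loss(shuffled, txt))  # much larger
```

Averaging the image-to-text and text-to-image terms is the usual way to train both encoders into one shared space; in CLIP itself the temperature is a learned parameter rather than a constant.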

Examples of Zero-Shot Matching:

  • Seen during training: "a red circle" → correctly matches the drawn red circle.
  • Not seen during training: "a teal lightning bolt" → correctly matches a hand-drawn lightning bolt shape, even though the model never saw a lightning bolt during training.

Limitations:

  • tinyvvision is a demonstration of zero-shot embedding and generalization on synthetic data. It was not trained on real-world images or complex scenes. It is robust within its domain (simple geometric shapes with clear captions), but results may vary significantly on more complicated or out-of-domain inputs.

How to Test tinyvvision:

Use the inference script below to test your own shapes and captions; it requires torch, numpy, Pillow, and huggingface_hub. Feel free to challenge tinyvvision with new, unseen combinations to explore how well it generalizes!

```python
import math
import re

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw

# Download the checkpoint from the Hugging Face Hub and pick the best available device.
repo = "ProCreations/tinyvvision"
pth = hf_hub_download(repo, "cortexclip-mini.pth")
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
state = torch.load(pth, map_location=device)

# Tokenizer: lowercase, split into words/punctuation, map to vocabulary ids.
# Unknown words map to id 0, which is also the padding id; captions are cut/padded to maxlen.
idx2tok = state["vocab"]
tok2idx = {t: i for i, t in enumerate(idx2tok)}

def encode_txt(s, maxlen=16):
    toks = re.findall(r"\w+|[^\w\s]", s.lower())
    ids = [tok2idx.get(t, 0) for t in toks][:maxlen]
    return ids + [0] * (maxlen - len(ids))

# Text encoder: token embedding -> 2-layer bidirectional GRU -> 128-d projection.
class TE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(len(idx2tok), 64)
        self.gru = torch.nn.GRU(64, 128, num_layers=2, bidirectional=True, batch_first=True)
        self.out_proj = torch.nn.Linear(256, 128)

    def forward(self, x):
        e, _ = self.gru(self.emb(x))
        return self.out_proj(e[:, -1])  # GRU output at the last (possibly padded) position

# Image encoder: three conv layers -> 4x4 adaptive average pooling -> 128-d projection.
class IE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 5, 1, 2), torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, 1, 1), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, 3, 1, 1), torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d((4, 4)), torch.nn.Flatten(),
            torch.nn.Linear(128 * 4 * 4, 128), torch.nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)

# Build both encoders, load the trained weights, and switch to inference mode.
te, ie = TE().to(device), IE().to(device)
te.load_state_dict(state["text_encoder"])
ie.load_state_dict(state["image_encoder"])
te.eval()
ie.eval()

def to_tensor(img):
    """Convert a PIL RGB image to a (1, 3, H, W) float tensor in [0, 1] on `device`."""
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)

# ----- CUSTOMIZE YOUR EXAMPLES HERE -----
# To try your own image:
#   1. Replace the body of custom_image() with your own drawing/loading code.
#   2. Replace custom_caption with your own caption for the image.
def custom_image():
    # Example: draw a "blue hexagon" on a 64x64 white canvas, like the demo images.
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    dr.regular_polygon((32, 32, 22), n_sides=6, fill="blue")
    # To load a file instead (the path is a placeholder):
    # img = Image.open("my_shape.png").convert("RGB").resize((64, 64))
    return to_tensor(img)

custom_caption = "a blue hexagon"

# ----- FUN DEMO EXAMPLES -----
def draw_red_heart():
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    # Simple heart: a diamond for the lower point plus two circles for the lobes.
    dr.polygon([(32, 18), (50, 34), (32, 56), (14, 34)], fill="red")
    dr.ellipse((18, 12, 32, 32), fill="red")
    dr.ellipse((32, 12, 46, 32), fill="red")
    return to_tensor(img)

def draw_purple_star():
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    # Five points on a circle of radius 20; joining every second point traces a pentagram.
    points = [(32 + 20 * math.cos(math.radians(a)), 32 + 20 * math.sin(math.radians(a)))
              for a in range(-90, 270, 72)]
    for i in range(5):
        dr.line([points[i], points[(i + 2) % 5]], fill="purple", width=7)
    return to_tensor(img)

def draw_orange_pentagon():
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    dr.regular_polygon((32, 32, 22), n_sides=5, fill="orange")
    return to_tensor(img)

demo_imgs = [
    (custom_image(), custom_caption),
    (draw_red_heart(), "a red heart"),
    (draw_purple_star(), "a purple star"),
    (draw_orange_pentagon(), "an orange pentagon"),
]
captions = [c for (_, c) in demo_imgs]
img_tensors = [im for (im, _) in demo_imgs]
cap_ids = torch.tensor([encode_txt(c) for c in captions], device=device)

# Embed all captions once, then score each image against every caption.
with torch.no_grad():
    txt_emb = te(cap_ids)  # (num_captions, 128)
    for i, (img, caption) in enumerate(zip(img_tensors, captions)):
        im_emb = ie(img)  # (1, 128), broadcast against txt_emb below
        sim = torch.nn.functional.cosine_similarity(im_emb, txt_emb).cpu().numpy()
        rank = int(np.argmax(sim))
        print(f"Input image {i + 1}: '{caption}'")
        print("  Similarity scores:")
        for j, c in enumerate(captions):
            print(f"    {c}: {sim[j]:.4f}")
        print("  Best match:", captions[rank], "\n")
```
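One detail worth knowing when you try "never seen" captions: the script's `encode_txt` maps any word missing from the checkpoint's vocabulary (`state["vocab"]`) to id 0, the same id it uses for padding, so the text encoder cannot distinguish an unknown word from padding. The sketch below copies that tokenizer logic and runs it against a small, made-up vocabulary; the toy vocabulary and the `unknown_words` helper are hypothetical, and whether a word such as "teal" is in the real vocabulary depends on the checkpoint.

```python
import re

# Hypothetical toy vocabulary; the real one is state["vocab"] in the checkpoint.
idx2tok = ["<pad>", "a", "an", "red", "blue", "circle", "star", "lightning", "bolt"]
tok2idx = {t: i for i, t in enumerate(idx2tok)}

def encode_txt(s, maxlen=16):
    """Same logic as the inference script: lowercase, split, map to ids, pad with 0."""
    toks = re.findall(r"\w+|[^\w\s]", s.lower())
    ids = [tok2idx.get(t, 0) for t in toks][:maxlen]
    return ids + [0] * (maxlen - len(ids))

def unknown_words(s):
    """Hypothetical helper: tokens of `s` that fall back to id 0 (not in the vocabulary)."""
    return [t for t in re.findall(r"\w+|[^\w\s]", s.lower()) if t not in tok2idx]

print(encode_txt("a teal lightning bolt", maxlen=8))  # [1, 0, 7, 8, 0, 0, 0, 0]
print(unknown_words("a teal lightning bolt"))         # ['teal']
```

To check real captions, drop the toy vocabulary lines and call `unknown_words` after the inference script has built its own `tok2idx`.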
✨ **Enjoy experimenting!** ✨