To build the model, I used ResNet18 for the image encoder and Turkish DistilBERT for the text encoder. Turkish DistilBERT: dbmdz/distilbert-base-turkish-cased

You can find more information (and code 🎉) on how to train and use the model on my GitHub.

How to use the model?

To use the model, you can use the Net class in model.py as in the example below:

from model import Net  # Net is defined in model.py (see the GitHub repository)
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer

model = Net()
# map_location is needed to load the checkpoint on a CPU-only machine
model.load_state_dict(torch.load("clip_model.pt", map_location=torch.device("cpu")))
model.eval()

tokenizer = AutoTokenizer.from_pretrained("dbmdz/distilbert-base-turkish-cased")

# Resize to 224x224, scale pixels to [0, 1], then normalize each channel
# with the ImageNet mean/std: (x - mean) / std
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def predict(img, descriptions):
    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    image = transform(img).unsqueeze(0)
    # Pad all descriptions to the same length and return PyTorch (int64) tensors
    tokens = tokenizer(descriptions, padding=True, return_tensors="pt")

    with torch.no_grad():  # inference only, no gradients needed
        image_vec, text_vec = model(image, tokens["input_ids"], tokens["attention_mask"])

    # Image-text similarity scores -> one probability per description
    probs = F.softmax(torch.matmul(image_vec, text_vec.T), dim=1)
    print(probs)
    return probs

# A dog image; convert("RGB") drops any alpha channel so the 3-channel Normalize works
img = Image.open("dog.png").convert("RGB")

# Descriptions: "A dog in the grass.", "A dog.", "A bird in the grass."
descriptions = ["Çimenler içinde bir köpek.", "Bir köpek.", "Çimenler içinde bir kuş."]
predict(img, descriptions)  # Probabilities for each description
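The last line of predict turns image-text similarity scores into probabilities. The small sketch below repeats that step with made-up 3-dimensional vectors standing in for the model's image_vec and text_vec (the function name rank_descriptions and all numbers are hypothetical), to show the shapes involved and how to pick the best-matching description:

```python
import torch
import torch.nn.functional as F


def rank_descriptions(image_vec, text_vec, descriptions):
    """Hypothetical helper: same softmax-over-similarities math as predict."""
    logits = torch.matmul(image_vec, text_vec.T)  # (n_images, n_descriptions)
    probs = F.softmax(logits, dim=1)  # each row sums to 1
    best = probs.argmax(dim=1)  # most likely description index per image
    return probs, [descriptions[i] for i in best.tolist()]


# Toy embeddings standing in for the model outputs (made-up numbers)
image_vec = torch.tensor([[1.0, 0.0, 0.0]])
text_vec = torch.tensor(
    [
        [0.9, 0.1, 0.0],
        [0.5, 0.5, 0.0],
        [0.0, 0.1, 0.9],
    ]
)
descriptions = ["Çimenler içinde bir köpek.", "Bir köpek.", "Çimenler içinde bir kuş."]
probs, best = rank_descriptions(image_vec, text_vec, descriptions)
print(probs)
print(best)
```

Because the softmax runs over dim=1, the probabilities are relative to the descriptions you pass in: adding or removing a candidate changes every score.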