File size: 4,835 Bytes
2294c6e
 
 
 
 
 
 
 
0bca9cd
065e4c1
2294c6e
065e4c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2294c6e
065e4c1
2294c6e
 
065e4c1
 
 
2294c6e
 
 
 
065e4c1
 
 
 
 
 
64ffe28
 
0bca9cd
 
 
 
 
 
 
 
 
 
64ffe28
065e4c1
 
 
0bca9cd
2294c6e
0bca9cd
2294c6e
 
 
 
 
0bca9cd
2294c6e
 
 
 
 
0bca9cd
2294c6e
 
065e4c1
2294c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bca9cd
2294c6e
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
from transformers import AutoModel, AutoTokenizer
import torch
import json
import requests
from PIL import Image
from torchvision import transforms
import urllib.request
from torchvision import models
import torch.nn as nn

# --- Define the Model ---
class FineGrainedClassifier(nn.Module):
    """Two-branch classifier that fuses image features with text features.

    The image branch is a ResNet-50 whose final ``fc`` layer is replaced by
    ``nn.Identity``; the text branch is the
    ``jinaai/jina-embeddings-v2-base-en`` encoder. Both feature vectors are
    concatenated and fed to an MLP head that emits one logit per class.
    """

    def __init__(self, num_classes=434):
        """Create both encoders and the classification head.

        Args:
            num_classes: Output width of the final linear layer (default 434).
        """
        super().__init__()

        # Image branch: with ``fc`` turned into an identity, the encoder
        # returns the features that would otherwise feed the ImageNet head.
        backbone = models.resnet50(pretrained=True)
        backbone.fc = nn.Identity()
        self.image_encoder = backbone

        # Text branch.
        self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')

        # Head input is sized for 2048 image features + 768 text features.
        head_layers = [
            nn.Linear(2048 + 768, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, image, input_ids, attention_mask):
        """Return class logits for a batch of (image, tokenized text) pairs.

        Args:
            image: Input passed straight to the image encoder
                (assumed to be a batched image tensor -- TODO confirm shape).
            input_ids: Token ids forwarded to the text encoder.
            attention_mask: Attention mask forwarded to the text encoder.

        Returns:
            The output of the classification head for the fused features.
        """
        visual_features = self.image_encoder(image)

        encoded_text = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the text representation.
        first_token_features = encoded_text.last_hidden_state[:, 0, :]

        fused_features = torch.cat((visual_features, first_token_features), dim=1)
        return self.classifier(fused_features)

# --- Inference Preprocessing ---
# This script only runs inference, so the pipeline must be deterministic:
# random flips, rotations and colour jitter are training-time augmentations
# and would make the same upload produce different predictions on every call.
# Only the resize, tensor conversion and normalisation steps are kept.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the label-to-class mapping from the Hugging Face repository.
label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
# A timeout keeps startup from hanging indefinitely on a stalled connection,
# and raise_for_status() surfaces HTTP errors directly instead of letting an
# error page fail later as a confusing JSON decode error.
label_map_response = requests.get(label_map_url, timeout=30)
label_map_response.raise_for_status()
label_to_class = label_map_response.json()

# Build the model with one output per label and fetch the trained weights.
model = FineGrainedClassifier(num_classes=len(label_to_class))
checkpoint_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
# NOTE(review): this unpickles a remote file; only point it at a trusted repository.
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))

# Strip a leading "module." from state-dict keys when present (typically left
# behind when the model was saved while wrapped in nn.DataParallel).
module_prefix = "module."
new_state_dict = {
    (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
    for key, value in checkpoint.items()
}

model.load_state_dict(new_state_dict)

# Load the tokenizer that matches the Jina text encoder.
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")

def load_image(image):
    """
    Convert an uploaded image into a single-item batch.

    Runs the module-level ``transform`` on the image and prepends a batch
    dimension to the result.
    """
    return transform(image).unsqueeze(0)

def predict(image, title, threshold=0.7):
    """
    Predict the most likely categories for the given image and title.

    Args:
        image: Product image accepted by ``load_image`` (the Gradio interface
            supplies a PIL image).
        title: Product title text. ``None`` is treated as an empty title.
        threshold: Minimum top-1 probability required to trust the prediction.
            When the best class scores below it, an "Others" entry is
            prepended to the result.

    Returns:
        A dict mapping class name to probability, ordered from most to least
        likely and holding up to three classes. When the top probability is
        below ``threshold`` the dict starts with an extra "Others" key whose
        value is ``1 - top probability``.

    Raises:
        ValueError: If no image was supplied.
    """
    # Fail early with a clear message instead of an obscure error deep inside
    # the preprocessing pipeline when the user submits without an image.
    if image is None:
        raise ValueError("An image is required to make a prediction.")

    # Preprocess the image
    image = load_image(image)

    # Tokenize the title
    title_encoding = tokenizer(title or "", padding='max_length', max_length=200, truncation=True, return_tensors='pt')
    input_ids = title_encoding['input_ids']
    attention_mask = title_encoding['attention_mask']

    # Predict. eval() is required so BatchNorm/Dropout behave deterministically
    # on this single-item batch.
    model.eval()
    with torch.no_grad():
        output = model(image, input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        # Never ask topk for more entries than there are classes.
        top_k = min(3, probabilities.shape[1])
        top_probabilities, top_indices = torch.topk(probabilities, top_k, dim=1)

    # Work with plain Python values from here on.
    top_probs = top_probabilities[0].tolist()
    top_classes = [label_to_class[str(index)] for index in top_indices[0].tolist()]

    results = {}
    # Low confidence: lead with "Others", scored as the complement of the best class.
    if top_probs[0] < threshold:
        results["Others"] = 1.0 - top_probs[0]
    results.update(zip(top_classes, top_probs))

    return results

# Define the Gradio interface.
# Components come from the top-level namespace: the gr.inputs / gr.outputs
# modules were deprecated in Gradio 3 and removed in Gradio 4, where they
# raise AttributeError at import time.
title_input = gr.Textbox(label="Product Title", placeholder="Enter the product title here...")
image_input = gr.Image(type="pil", label="Upload Image")
output = gr.JSON(label="Top 3 Predictions with Probabilities")

# Input order must match predict(image, title).
gr.Interface(
    fn=predict,
    inputs=[image_input, title_input],
    outputs=output,
    title="Ecommerce Classifier",
    description="This model classifies ecommerce products into one of 434 categories. If the model is unsure, it outputs 'Others'.",
).launch()