"""Gradio demo: classify an e-commerce product from its image and title.

On import this module downloads the label map, the model checkpoint and the
tokenizer, then builds a Gradio interface around :func:`predict`. The server
is started only when the file is executed as a script.
"""

import io
import json
import os
import urllib.request
from typing import Dict

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer

# Hugging Face identifier shared by the text encoder and its tokenizer.
TEXT_ENCODER_NAME = 'jinaai/jina-embeddings-v2-base-en'

# Upper bound (seconds) for every blocking HTTP request made by this module,
# so a stalled remote host cannot hang start-up or a prediction forever.
REQUEST_TIMEOUT_SECONDS = 30

# Channel statistics used by both the training and the inference pipeline.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]


# --- Define the Model ---
class FineGrainedClassifier(nn.Module):
    """Image + text classifier: ResNet-50 features joined with text features.

    The image encoder's final ``fc`` layer is replaced by ``nn.Identity`` and
    its output is concatenated with the first-token hidden state of the text
    encoder. The classifier head expects the concatenation to be
    ``2048 + 768`` wide and maps it to ``num_classes`` logits.
    """

    def __init__(self, num_classes=434):
        """Build the encoders and the classification head.

        Args:
            num_classes: Number of output logits (defaults to 434).
        """
        super().__init__()
        # NOTE(review): `pretrained=True` is kept for compatibility with the
        # installed torchvision; newer releases prefer `weights=` - confirm
        # before changing.
        self.image_encoder = models.resnet50(pretrained=True)
        self.image_encoder.fc = nn.Identity()
        # NOTE(review): the encoder is loaded without `trust_remote_code`;
        # verify this matches how the published checkpoint was trained.
        self.text_encoder = AutoModel.from_pretrained(TEXT_ENCODER_NAME)
        self.classifier = nn.Sequential(
            nn.Linear(2048 + 768, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, image, input_ids, attention_mask):
        """Return class logits for a batch of (image, tokenized title) pairs.

        Args:
            image: Batched image tensor fed to the image encoder
                (presumably ``(batch, 3, H, W)`` - confirm against callers).
            input_ids: Token ids passed to the text encoder.
            attention_mask: Attention mask passed to the text encoder.

        Returns:
            The output of the classifier head, one row of logits per sample.
        """
        image_features = self.image_encoder(image)
        text_output = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
        # Use the hidden state at sequence position 0 as the text embedding.
        text_features = text_output.last_hidden_state[:, 0, :]
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.classifier(combined_features)


# --- Data Augmentation Setup ---
# Randomised pipeline intended for training. It is kept under its original
# name for compatibility but must NOT be used at inference time: the random
# flip/rotation/colour jitter would make predictions non-reproducible.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])

# Deterministic preprocessing used when serving predictions: same resize and
# normalisation as above, without any random augmentation.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])


def _unwrap_state_dict(checkpoint):
    """Return a state dict that can be loaded into an unwrapped model.

    Accepts either a plain state dict or a checkpoint dictionary that stores
    the weights under ``'model_state_dict'``. A leading ``"module."`` on any
    key is removed. A plain state dict without such prefixes is returned with
    identical keys and values.
    """
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }


# Load the label-to-class mapping from the Hugging Face repository.
label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
_label_response = requests.get(label_map_url, timeout=REQUEST_TIMEOUT_SECONDS)
# Fail with a clear HTTP error instead of trying to parse an error page as JSON.
_label_response.raise_for_status()
label_to_class = _label_response.json()

# Load the custom model from Hugging Face.
model = FineGrainedClassifier(num_classes=len(label_to_class))
model_checkpoint = "Maverick98/EcommerceClassifier"
model.load_state_dict(
    _unwrap_state_dict(
        torch.hub.load_state_dict_from_url(
            f"https://huggingface.co/{model_checkpoint}/resolve/main/model_checkpoint.pth",
            map_location=torch.device('cpu'),
        )
    )
)
# The model is only used for inference here, so switch off dropout and put
# batch-norm in evaluation mode once instead of on every request.
model.eval()

# Load the tokenizer from Jina.
tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_NAME)


def load_image(image_path_or_url):
    """Load an image from a URL or local path and preprocess it for the model.

    Args:
        image_path_or_url: An ``http://``/``https://`` URL or a filesystem
            path.

    Returns:
        The image converted to RGB, passed through ``inference_transform``,
        with a leading batch dimension of size 1 added.

    Raises:
        urllib.error.URLError: If the URL cannot be fetched.
        OSError: If the local file cannot be opened or is not a valid image.
    """
    # NOTE(security): the value comes straight from the web UI, so this will
    # fetch arbitrary URLs and open arbitrary local paths. Restrict it if the
    # app is exposed to untrusted users.
    if image_path_or_url.startswith(("http://", "https://")):
        with urllib.request.urlopen(
            image_path_or_url, timeout=REQUEST_TIMEOUT_SECONDS
        ) as response:
            # Buffer the payload so PIL gets a seekable stream.
            payload = io.BytesIO(response.read())
        image = Image.open(payload).convert('RGB')
    else:
        # The context manager closes the underlying file handle; `convert`
        # returns an independent, fully loaded copy.
        with Image.open(image_path_or_url) as source:
            image = source.convert('RGB')

    tensor = inference_transform(image)
    return tensor.unsqueeze(0)  # Add batch dimension


def predict(image_path_or_url, title, threshold=0.7) -> Dict[str, float]:
    """Predict the top categories for the given image and title.

    Args:
        image_path_or_url: Image URL or local path, see :func:`load_image`.
        title: Product title text.
        threshold: If the highest class probability is below this value an
            extra ``"Others"`` entry is placed first, with value
            ``1 - top probability``.

    Returns:
        A dict mapping class name to probability for the (up to) three most
        likely classes, preceded by ``"Others"`` when the model is unsure.
    """
    image = load_image(image_path_or_url)

    title_encoding = tokenizer(
        title,
        padding='max_length',
        max_length=200,
        truncation=True,
        return_tensors='pt',
    )
    input_ids = title_encoding['input_ids']
    attention_mask = title_encoding['attention_mask']

    with torch.no_grad():
        logits = model(image, input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        # Never ask for more entries than there are classes.
        k = min(3, probabilities.shape[1])
        top_probabilities, top_indices = torch.topk(probabilities, k, dim=1)

    # Single-sample batch: take row 0 and move to plain Python values.
    scores = top_probabilities[0].tolist()
    class_names = [label_to_class[str(index)] for index in top_indices[0].tolist()]

    if scores[0] < threshold:
        class_names.insert(0, "Others")
        scores.insert(0, 1.0 - scores[0])

    return dict(zip(class_names, scores))


# Define the Gradio interface. Components are created through the top-level
# `gr.Textbox` / `gr.JSON` classes; the `gr.inputs` / `gr.outputs` namespaces
# are not available in current Gradio releases.
title_input = gr.Textbox(label="Product Title", placeholder="Enter the product title here...")
image_input = gr.Textbox(label="Image URL or Path", placeholder="Enter image URL or local path here...")
output = gr.JSON(label="Top 3 Predictions with Probabilities")

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, title_input],
    outputs=output,
    title="Ecommerce Classifier",
    description="This model classifies ecommerce products into one of 434 categories. If the model is unsure, it outputs 'Others'.",
)

if __name__ == "__main__":
    demo.launch()