# Maverick98 — app.py
# Last commit: "Update app.py" (699133d, verified) · 5.93 kB
import gradio as gr
from transformers import AutoModel, AutoTokenizer
import torch
import json
import requests
from PIL import Image
from torchvision import transforms
import urllib.request
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from transformers import AutoModel, AutoTokenizer
from torchvision import models, transforms
from torch.cuda.amp import GradScaler, autocast
import numpy as np
# --- Define the Model ---
class FineGrainedClassifier(nn.Module):
    """Multimodal product classifier.

    A ResNet-50 image embedding is concatenated with the first-token hidden
    state of the Jina text encoder, and the result is passed through a
    two-hidden-layer MLP head.

    Args:
        num_classes: Number of output logits. The app constructs the model
            with ``len(label_to_class)``; 434 is only the default.
    """

    def __init__(self, num_classes=434):
        super().__init__()

        # ResNet-50 backbone. Its final classification layer is swapped for an
        # identity so forward() returns the pooled feature vector.
        self.image_encoder = models.resnet50(pretrained=True)
        image_dim = self.image_encoder.fc.in_features  # 2048 for ResNet-50
        self.image_encoder.fc = nn.Identity()

        self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')
        text_dim = self.text_encoder.config.hidden_size  # 768 for this encoder

        # Layer order/indices must stay as-is: the saved checkpoint's
        # state_dict keys (classifier.0, classifier.1, ...) depend on them.
        self.classifier = nn.Sequential(
            nn.Linear(image_dim + text_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, image, input_ids, attention_mask):
        """Return unnormalised class logits of shape ``(batch, num_classes)``.

        Args:
            image: Preprocessed image batch for the ResNet encoder
                (assumed ``(batch, 3, H, W)`` — see the module-level transform).
            input_ids: Token ids from the matching Jina tokenizer.
            attention_mask: Attention mask aligned with ``input_ids``.
        """
        image_features = self.image_encoder(image)
        text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Uses the first token's hidden state as the sentence embedding. The
        # checkpoint was trained this way, so don't swap pooling strategies
        # without retraining.
        text_features = text_output.last_hidden_state[:, 0, :]
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.classifier(combined_features)
# --- Image preprocessing ---
# Training-time augmentations. They are random on every call, so they must
# NOT be used for inference: the same image would be flipped, rotated and
# colour-jittered differently per request and could get a different
# prediction. Kept for reference under a separate name.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Deterministic preprocessing used by load_image() at inference time. It
# resizes to 224x224, converts to a float tensor and normalises with the
# standard ImageNet per-channel mean/std (the same values used above).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def _extract_state_dict(checkpoint):
    """Return a state dict loadable into an unwrapped ``FineGrainedClassifier``.

    Accepts either a bare state dict or a training checkpoint that stores it
    under ``'model_state_dict'``. If the keys carry the ``module.`` prefix
    added when a model is wrapped in ``nn.DataParallel`` /
    ``DistributedDataParallel``, the prefix is stripped. A bare, unprefixed
    state dict is returned unchanged.
    """
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    prefix = 'module.'
    if not any(key.startswith(prefix) for key in checkpoint):
        return checkpoint
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }


# Load the label-to-class mapping from the Hugging Face repository.
label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
_label_map_response = requests.get(label_map_url, timeout=30)
# Fail loudly on 4xx/5xx instead of trying to parse an error page as JSON.
_label_map_response.raise_for_status()
label_to_class = _label_map_response.json()

# Build the model with one output per known label, then load the trained weights.
model = FineGrainedClassifier(num_classes=len(label_to_class))
model_checkpoint = "Maverick98/EcommerceClassifier"
_checkpoint = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{model_checkpoint}/resolve/main/model_checkpoint.pth",
    map_location=torch.device('cpu'),
)
model.load_state_dict(_extract_state_dict(_checkpoint))
model.eval()  # inference only: freeze BatchNorm statistics, disable Dropout

# Load the tokenizer matching the model's text encoder.
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
# # Define image preprocessing
# transform = transforms.Compose([
# transforms.Resize((224, 224)),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])
def load_image(image_path_or_url, timeout=30):
    """Load an image from an http(s) URL or a local path and preprocess it.

    Args:
        image_path_or_url: ``http://``/``https://`` URL or filesystem path.
            Leading/trailing whitespace (common in pasted input) is ignored.
        timeout: Seconds to wait on a remote image before giving up, so a slow
            host cannot hang the request indefinitely.

    Returns:
        The module-level ``transform`` applied to the RGB image, with a
        leading batch dimension of size 1.

    Raises:
        ValueError: If no path or URL was given.
    """
    source = (image_path_or_url or "").strip()
    if not source:
        raise ValueError("No image URL or path was provided.")

    # Match the real URL schemes only, so a local file named "http..." is
    # still treated as a path.
    if source.lower().startswith(("http://", "https://")):
        with urllib.request.urlopen(source, timeout=timeout) as response:
            # convert() forces the pixel data to be read before the
            # connection is closed at the end of the with-block.
            image = Image.open(response).convert('RGB')
    else:
        with Image.open(source) as raw:
            image = raw.convert('RGB')

    return transform(image).unsqueeze(0)  # add batch dimension
def predict(image_path_or_url, title, threshold=0.7, top_k=3):
    """Predict the most likely categories for a product image and title.

    Args:
        image_path_or_url: Image URL or local path, passed to ``load_image``.
        title: Product title text; ``None`` is treated as an empty title.
        threshold: If the top class probability is below this value, an
            ``"Others"`` entry with probability ``1 - top_probability`` is
            placed first in the result.
        top_k: Number of classes to return. It is capped at the number of
            classes and is at least 1.

    Returns:
        A dict mapping class name to softmax probability, in descending order
        of probability, with ``"Others"`` first when it is added.
    """
    image = load_image(image_path_or_url)

    encoding = tokenizer(
        title or "",
        padding='max_length',
        max_length=200,
        truncation=True,
        return_tensors='pt',
    )

    model.eval()
    with torch.no_grad():
        logits = model(
            image,
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
        )
    # Single-image batch: take row 0 of the (1, num_classes) probabilities.
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = max(1, min(top_k, probabilities.numel()))
    top_probabilities, top_indices = torch.topk(probabilities, k)

    results = {}
    best = top_probabilities[0].item()
    if best < threshold:
        results["Others"] = 1.0 - best
    # label_to_class is keyed by the class index as a string (loaded from JSON).
    for prob, idx in zip(top_probabilities.tolist(), top_indices.tolist()):
        results[label_to_class[str(idx)]] = prob
    return results
# Define the Gradio interface.
# gr.inputs / gr.outputs were deprecated in Gradio 3.x and removed in 4.0;
# the top-level components below work in both versions.
title_input = gr.Textbox(label="Product Title", placeholder="Enter the product title here...")
image_input = gr.Textbox(label="Image URL or Path", placeholder="Enter image URL or local path here...")
output = gr.JSON(label="Top 3 Predictions with Probabilities")

demo = gr.Interface(
    fn=predict,
    # Order must match predict(image_path_or_url, title).
    inputs=[image_input, title_input],
    outputs=output,
    title="Ecommerce Classifier",
    description="This model classifies ecommerce products into one of 434 categories. If the model is unsure, it outputs 'Others'.",
)
demo.launch()