Maverick98 commited on
Commit
065e4c1
·
verified ·
1 Parent(s): c56d5d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -8
app.py CHANGED
@@ -6,22 +6,95 @@ import requests
6
  from PIL import Image
7
  from torchvision import transforms
8
  import urllib.request
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Load the label-to-class mapping from your Hugging Face repository
11
- label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
12
- label_to_class = requests.get(label_map_url).json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Load the model and tokenizer from your Hugging Face repository
15
- model = AutoModel.from_pretrained("Maverick98/EcommerceClassifier")
16
- tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
17
 
18
- # Define image preprocessing
19
  transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
 
 
 
21
  transforms.ToTensor(),
22
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
23
  ])
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def load_image(image_path_or_url):
26
  """
27
  Load an image from a URL or local path and preprocess it.
@@ -45,7 +118,7 @@ def predict(image_path_or_url, title, threshold=0.7):
45
  image = load_image(image_path_or_url)
46
 
47
  # Tokenize the title
48
- title_encoding = tokenizer(title, padding='max_length', max_length=32, truncation=True, return_tensors='pt')
49
  input_ids = title_encoding['input_ids']
50
  attention_mask = title_encoding['attention_mask']
51
 
 
6
  from PIL import Image
7
  from torchvision import transforms
8
  import urllib.request
9
+ import os
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler
14
+ from transformers import AutoModel, AutoTokenizer
15
+ from torchvision import models, transforms
16
+ from sklearn.model_selection import train_test_split
17
+ from sklearn.utils.class_weight import compute_class_weight
18
+ from sklearn.metrics import precision_recall_fscore_support, accuracy_score
19
+ from torch.cuda.amp import GradScaler, autocast
20
+ import numpy as np
21
+ import torch.multiprocessing as mp
22
+ import torch.distributed as dist
23
+ import matplotlib.pyplot as plt
24
 
25
+ # --- Define the Model ---
26
+ class FineGrainedClassifier(nn.Module):
27
+ def __init__(self, num_classes=434): # Updated to 434 classes
28
+ super(FineGrainedClassifier, self).__init__()
29
+ self.image_encoder = models.resnet50(pretrained=True)
30
+ self.image_encoder.fc = nn.Identity()
31
+ self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')
32
+ self.classifier = nn.Sequential(
33
+ nn.Linear(2048 + 768, 1024),
34
+ nn.BatchNorm1d(1024),
35
+ nn.ReLU(),
36
+ nn.Dropout(0.3),
37
+ nn.Linear(1024, 512),
38
+ nn.BatchNorm1d(512),
39
+ nn.ReLU(),
40
+ nn.Dropout(0.3),
41
+ nn.Linear(512, num_classes) # Updated to 434 classes
42
+ )
43
+
44
+ def forward(self, image, input_ids, attention_mask):
45
+ image_features = self.image_encoder(image)
46
+ text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
47
+ text_features = text_output.last_hidden_state[:, 0, :]
48
+ combined_features = torch.cat((image_features, text_features), dim=1)
49
+ output = self.classifier(combined_features)
50
+ return output
51
 
 
 
 
52
 
53
+ # --- Data Augmentation Setup ---
54
  transform = transforms.Compose([
55
  transforms.Resize((224, 224)),
56
+ transforms.RandomHorizontalFlip(),
57
+ transforms.RandomRotation(15),
58
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
59
  transforms.ToTensor(),
60
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
61
  ])
62
 
63
+ # def load_model_checkpoint(model, checkpoint_path, device):
64
+ # checkpoint = torch.load(checkpoint_path, map_location=device)
65
+
66
+ # # Strip the "module." prefix from the keys in the state_dict if they exist
67
+ # state_dict = checkpoint['model_state_dict']
68
+ # new_state_dict = {}
69
+
70
+ # for k, v in state_dict.items():
71
+ # if k.startswith("module."):
72
+ # new_state_dict[k[7:]] = v # Remove "module." prefix
73
+ # else:
74
+ # new_state_dict[k] = v
75
+
76
+ # model.load_state_dict(new_state_dict)
77
+ # return model
78
+
79
+ # Load the label-to-class mapping from your Hugging Face repository
80
+ label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
81
+ label_to_class = requests.get(label_map_url).json()
82
+
83
+
84
+ # Load your custom model from Hugging Face
85
+ model = FineGrainedClassifier(num_classes=len(label_to_class))
86
+ model_checkpoint = "Maverick98/EcommerceClassifier"
87
+ model.load_state_dict(torch.hub.load_state_dict_from_url(f"https://huggingface.co/{model_checkpoint}/resolve/main/model_checkpoint.pth", map_location=torch.device('cpu')))
88
+ # Load the tokenizer from Jina
89
+ tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
90
+
91
+ # # Define image preprocessing
92
+ # transform = transforms.Compose([
93
+ # transforms.Resize((224, 224)),
94
+ # transforms.ToTensor(),
95
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
96
+ # ])
97
+
98
  def load_image(image_path_or_url):
99
  """
100
  Load an image from a URL or local path and preprocess it.
 
118
  image = load_image(image_path_or_url)
119
 
120
  # Tokenize the title
121
+ title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')
122
  input_ids = title_encoding['input_ids']
123
  attention_mask = title_encoding['attention_mask']
124