Maverick98 committed
Commit 0bca9cd · verified · 1 Parent(s): 64ffe28
Files changed (1)
  1. app.py +16 -41
app.py CHANGED
@@ -6,15 +6,8 @@ import requests
 from PIL import Image
 from torchvision import transforms
 import urllib.request
-import os
-import torch
+from torchvision import models
 import torch.nn as nn
-import torch.optim as optim
-from torch.utils.data import DataLoader, Dataset, DistributedSampler
-from transformers import AutoModel, AutoTokenizer
-from torchvision import models, transforms
-from torch.cuda.amp import GradScaler, autocast
-import numpy as np
 
 # --- Define the Model ---
 class FineGrainedClassifier(nn.Module):
@@ -43,7 +36,6 @@ class FineGrainedClassifier(nn.Module):
         output = self.classifier(combined_features)
         return output
 
-
 # --- Data Augmentation Setup ---
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
@@ -54,60 +46,43 @@ transform = transforms.Compose([
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 ])
 
-# def load_model_checkpoint(model, checkpoint_path, device):
-#     checkpoint = torch.load(checkpoint_path, map_location=device)
-
-#     # Strip the "module." prefix from the keys in the state_dict if they exist
-#     state_dict = checkpoint['model_state_dict']
-#     new_state_dict = {}
-
-#     for k, v in state_dict.items():
-#         if k.startswith("module."):
-#             new_state_dict[k[7:]] = v  # Remove "module." prefix
-#         else:
-#             new_state_dict[k] = v
-
-#     model.load_state_dict(new_state_dict)
-#     return model
-
 # Load the label-to-class mapping from your Hugging Face repository
 label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
 label_to_class = requests.get(label_map_url).json()
 
-
 # Load your custom model from Hugging Face
 model = FineGrainedClassifier(num_classes=len(label_to_class))
-model_checkpoint = "Maverick98/EcommerceClassifier"
 checkpoint_url = f"https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
 checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
-# Extract and load the model state_dict
-model.load_state_dict(checkpoint['model_state_dict'])
+
+# Strip the "module." prefix from the keys in the state_dict if they exist
+new_state_dict = {}
+for k, v in checkpoint.items():
+    if k.startswith("module."):
+        new_state_dict[k[7:]] = v  # Remove "module." prefix
+    else:
+        new_state_dict[k] = v
+
+model.load_state_dict(new_state_dict)
 
 # Load the tokenizer from Jina
 tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
 
-
-def load_image(image_path_or_url):
+def load_image(image):
     """
-    Load an image from a URL or local path and preprocess it.
+    Preprocess the uploaded image.
     """
-    if image_path_or_url.startswith("http"):
-        with urllib.request.urlopen(image_path_or_url) as url:
-            image = Image.open(url).convert('RGB')
-    else:
-        image = Image.open(image_path_or_url).convert('RGB')
-
     image = transform(image)
     image = image.unsqueeze(0)  # Add batch dimension
     return image
 
-def predict(image_path_or_url, title, threshold=0.7):
+def predict(image, title, threshold=0.7):
     """
     Predict the top 3 categories for the given image and title.
     Includes "Others" if the confidence of the top prediction is below the threshold.
     """
     # Preprocess the image
-    image = load_image(image_path_or_url)
+    image = load_image(image)
 
     # Tokenize the title
     title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')
@@ -138,7 +113,7 @@ def predict(image_path_or_url, title, threshold=0.7):
 
 # Define the Gradio interface
 title_input = gr.inputs.Textbox(label="Product Title", placeholder="Enter the product title here...")
-image_input = gr.inputs.Textbox(label="Image URL or Path", placeholder="Enter image URL or local path here...")
+image_input = gr.inputs.Image(type="pil", label="Upload Image")
 output = gr.outputs.JSON(label="Top 3 Predictions with Probabilities")
 
 gr.Interface(
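
The new loading block treats the downloaded checkpoint as a raw state_dict and strips the DataParallel "module." prefix inline, where the old code expected the weights under checkpoint['model_state_dict']. A minimal sketch of the same normalization as a reusable helper is shown below for reference; the helper name and the fallback to a nested 'model_state_dict' key are illustrative and not part of this commit.

def strip_module_prefix(checkpoint):
    """Return a state_dict with any DataParallel 'module.' prefix removed."""
    # Fall back to a nested 'model_state_dict' entry if present (the format the
    # old code assumed); otherwise treat the checkpoint itself as the state_dict.
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    return {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }

# Equivalent to the inlined loop in the diff above:
# model.load_state_dict(strip_module_prefix(checkpoint))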
 
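
The image input switches from a URL/path textbox to a direct PIL image upload, matching the new load_image(image) signature. The gr.Interface( call itself is cut off in this view; below is a minimal sketch of how the pieces could be wired end to end. It is written against the current Gradio API (gr.Textbox / gr.Image / gr.JSON) rather than the older gr.inputs / gr.outputs namespace the file uses, the demo title is a placeholder, and the input order is assumed to follow predict(image, title).

import gradio as gr

# Sketch only: assumes predict() from app.py above is in scope.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),  # passed through load_image(image)
        gr.Textbox(label="Product Title", placeholder="Enter the product title here..."),
    ],
    outputs=gr.JSON(label="Top 3 Predictions with Probabilities"),
    title="EcommerceClassifier demo",  # placeholder title
)

if __name__ == "__main__":
    demo.launch()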