Maverick98 commited on
Commit
80f1c54
·
verified ·
1 Parent(s): 7a10cff

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +220 -3
README.md CHANGED
@@ -1,3 +1,220 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ # Model Card for EcommerceClassifier
6
+
7
+ **EcommerceClassifier** is a fine-grained product classifier specifically designed for e-commerce platforms. This model leverages both product images and titles to classify items into one of 434 categories across two primary e-commerce domains: Grocery & Gourmet and Health & Household.
8
+
9
+ ## Model Details
10
+
11
+ ### Model Description
12
+
13
+ EcommerceClassifier is a multi-modal deep learning model developed to enhance product categorization in e-commerce settings. It integrates image and text data to provide accurate classifications, ensuring that products are correctly placed in their respective categories. This model is particularly useful in automating the product categorization process, optimizing search results, and improving recommendation systems.
14
+
15
+ - **Developed by:** [Mohit Dhawan]
16
+ - **Model type:** Multi-modal classification model
17
+ - **Language(s) (NLP):** English (product titles)
18
+ - **License:** Apache 2.0
19
+ - **Finetuned from model:** ResNet50 for image encoding, Jina's embeddings for text encoding
20
+
21
+ ### Model Sources
22
+
23
+ - **Repository:** [Repository URL on Hugging Face]
24
+ - **Demo:** [Gradio Demo URL]
25
+
26
+ ## Uses
27
+
28
+ ### Direct Use
29
+
30
+ EcommerceClassifier is intended for direct use in e-commerce platforms to automate and improve the accuracy of product classification. It can be integrated into existing systems to classify new products, enhance search functionality, and improve the relevancy of recommendations.
31
+
32
+ ### Downstream Use
33
+
34
+ EcommerceClassifier can be fine-tuned for specific e-commerce categories or extended to include additional product domains. It can also be integrated into larger e-commerce systems for fraud detection, where misclassified or counterfeit products are flagged.
35
+
36
+ ### Out-of-Scope Use
37
+
38
+ EcommerceClassifier is not intended for use outside of e-commerce product classification, particularly in contexts where the input data is significantly different from the domains it was trained on. Misuse includes attempts to classify non-e-commerce-related images or texts.
39
+
40
+ ## Bias, Risks, and Limitations
41
+
42
+ While EcommerceClassifier is trained on a diverse dataset, it may still exhibit biases inherent in the training data, particularly if certain categories are underrepresented. There is also a risk of overfitting to specific visual or textual features, which may reduce its effectiveness on new, unseen data.
43
+
44
+ ### Recommendations
45
+
46
+ Users should be aware of the potential biases in the model and consider re-training or fine-tuning EcommerceClassifier with more diverse or updated data as needed. Regular evaluation of the model's performance on new data is recommended to ensure it continues to perform accurately.
47
+
48
+ ## How to Get Started with the Model
49
+
50
+ Use the code below to get started with EcommerceClassifier:
51
+
52
+ ```python
53
+ import torch
54
+ from transformers import AutoModel, AutoTokenizer
55
+ import json
56
+ import requests
57
+ from PIL import Image
58
+ from torchvision import transforms
59
+ import urllib.request
60
+
61
+ # Load the label-to-class mapping from Hugging Face
62
+ label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
63
+ label_to_class = requests.get(label_map_url).json()
64
+
65
+ # Load the model and tokenizer
66
+ model = AutoModel.from_pretrained("Maverick98/EcommerceClassifier")
67
+ tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
68
+
69
+ # Define image preprocessing
70
+ transform = transforms.Compose([
71
+ transforms.Resize((224, 224)),
72
+ transforms.ToTensor(),
73
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
74
+ ])
75
+
76
+ def load_image(image_path_or_url):
77
+ if image_path_or_url.startswith("http"):
78
+ with urllib.request.urlopen(image_path_or_url) as url:
79
+ image = Image.open(url).convert('RGB')
80
+ else:
81
+ image = Image.open(image_path_or_url).convert('RGB')
82
+
83
+ image = transform(image)
84
+ image = image.unsqueeze(0) # Add batch dimension
85
+ return image
86
+
87
+ def predict(image_path_or_url, title, threshold=0.7):
88
+ # Preprocess the image
89
+ image = load_image(image_path_or_url)
90
+
91
+ # Tokenize title
92
+ title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')
93
+ input_ids = title_encoding['input_ids']
94
+ attention_mask = title_encoding['attention_mask']
95
+
96
+ # Predict
97
+ model.eval()
98
+ with torch.no_grad():
99
+ output = model(image, input_ids=input_ids, attention_mask=attention_mask)
100
+ probabilities = torch.nn.functional.softmax(output, dim=1)
101
+ top3_probabilities, top3_indices = torch.topk(probabilities, 3, dim=1)
102
+
103
+ # Map the top 3 indices to class names
104
+ top3_classes = [label_to_class[str(idx.item())] for idx in top3_indices[0]]
105
+
106
+ # Check if the highest probability is below the threshold
107
+ if top3_probabilities[0][0].item() < threshold:
108
+ top3_classes.insert(0, "Others")
109
+ top3_probabilities = torch.cat((torch.tensor([[1.0 - top3_probabilities[0][0].item()]]), top3_probabilities), dim=1)
110
+
111
+ # Output the class names and their probabilities
112
+ results = {}
113
+ for i in range(len(top3_classes)):
114
+ results[top3_classes[i]] = top3_probabilities[0][i].item()
115
+
116
+ return results
117
+
118
+ # Example usage
119
+ image_url = "https://example.com/path_to_your_image.jpg" # Replace with actual image URL or local path
120
+ title = "Organic Green Tea"
121
+ results = predict(image_url, title)
122
+
123
+ print("Prediction Results:")
124
+ for class_name, prob in results.items():
125
+ print(f"Class: {class_name}, Probability: {prob}")
126
+
127
+ # Map the top 3 indices to class names
128
+ top3_classes = [label_to_class[str(idx.item())] for idx in top3_indices[0]]
129
+
130
+ # Output the class names and their probabilities
131
+ for i in range(3):
132
+ print(f"Class: {top3_classes[i]}, Probability: {top3_probabilities[0][i].item()}")
133
+
134
+ ```
135
+
136
+ # Training Details
137
+
138
+ ## Training Data
139
+
140
+ EcommerceClassifier was trained on a dataset scraped from Amazon, focusing on two primary product nodes:
141
+
142
+ - **Grocery & Gourmet**
143
+ - **Health & Household**
144
+
145
+ The dataset includes over 434 categories with product images and titles, providing a comprehensive basis for training the model.
146
+
147
+ ## Training Procedure
148
+
149
+ ### Preprocessing:
150
+
151
+ - Images were resized to 224x224 pixels.
152
+ - Titles were tokenized using Jina’s embedding model.
153
+ - Data augmentation techniques such as random horizontal flip, random rotation, and color jitter were applied to images during training.
154
+
155
+ ### Training Hyperparameters:
156
+
157
+ - **Training regime:** Mixed precision (fp16)
158
+ - **Optimizer:** AdamW
159
+ - **Learning Rate:** 1e-4
160
+ - **Epochs:** 20
161
+ - **Batch Size:** 8
162
+ - **Accumulation Steps:** 4
163
+
164
+ ### Speeds, Sizes, Times:
165
+
166
+ The model was trained over 20 epochs using an NVIDIA A10 GPU, with each epoch taking approximately 30 minutes.
167
+
168
+ # Evaluation
169
+
170
+ ## Testing Data, Factors & Metrics
171
+
172
+ ### Testing Data
173
+
174
+ The model was evaluated on a validation dataset held out from the training data. The testing data includes a balanced representation of all 434 categories.
175
+
176
+ ### Factors
177
+
178
+ Evaluation factors include subpopulations within the Grocery & Gourmet and Health & Household categories.
179
+
180
+ ### Metrics
181
+
182
+ The model was evaluated using the following metrics:
183
+
184
+ - **Accuracy:** The overall correctness of the model's predictions.
185
+ - **Precision and Recall:** Evaluated per class to ensure balanced performance across all categories.
186
+
187
+ ## Results
188
+
189
+ The model achieved an overall accuracy of 83%, with a balanced precision and recall across most categories. Precision and recall tend to be low in the aggregated classes such as assortments, gift pack etc. The "others" category effectively captured instances where the model's confidence in the top predictions was low.
190
+
191
+ ## Summary
192
+
193
+ EcommerceClassifier demonstrated strong performance across the majority of categories, with particular strengths in well-represented classes. Future work may focus on enhancing performance in categories with fewer training examples.
194
+
195
+ # Environmental Impact
196
+
197
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in Lacoste et al. (2019).
198
+
199
+ - **Hardware Type:** NVIDIA A10 GPUs
200
+ - **Hours used:** ~10 hours total training time
201
+
202
+ # Technical Specifications
203
+
204
+ ## Model Architecture and Objective
205
+
206
+ The model consists of a ResNet50-based image encoder and a Jina embeddings-based text encoder, combined through fully connected layers to classify into 434 categories.
207
+
208
+ ## Compute Infrastructure
209
+
210
+ - **Hardware:** NVIDIA A10 GPUs
211
+ - **Software:** The model was implemented using PyTorch and Hugging Face Transformers libraries.
212
+
213
+
214
+ # Model Card Authors
215
+
216
+ Mohit Dhawan
217
+
218
+ # Model Card Contact
219
+
220
+ For inquiries, please contact [[email protected]]