Maverick98 commited on
Commit
12802d0
·
verified ·
1 Parent(s): 385048d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +50 -10
README.md CHANGED
@@ -51,19 +51,66 @@ Use the code below to get started with EcommerceClassifier:
51
 
52
  ```python
53
  import torch
54
- from transformers import AutoModel, AutoTokenizer
55
  import json
56
  import requests
57
  from PIL import Image
58
  from torchvision import transforms
59
  import urllib.request
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # Load the label-to-class mapping from Hugging Face
62
  label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
63
  label_to_class = requests.get(label_map_url).json()
64
 
65
- # Load the model and tokenizer
66
- model = AutoModel.from_pretrained("Maverick98/EcommerceClassifier")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
68
 
69
  # Define image preprocessing
@@ -124,13 +171,6 @@ print("Prediction Results:")
124
  for class_name, prob in results.items():
125
  print(f"Class: {class_name}, Probability: {prob}")
126
 
127
- # Map the top 3 indices to class names
128
- top3_classes = [label_to_class[str(idx.item())] for idx in top3_indices[0]]
129
-
130
- # Output the class names and their probabilities
131
- for i in range(3):
132
- print(f"Class: {top3_classes[i]}, Probability: {top3_probabilities[0][i].item()}")
133
-
134
  ```
135
 
136
  # Training Details
 
51
 
52
  ```python
53
  import torch
54
+ from transformers import AutoTokenizer, AutoModel
55
  import json
56
  import requests
57
  from PIL import Image
58
  from torchvision import transforms
59
  import urllib.request
60
+ import torch.nn as nn
61
+
62
+ # --- Define the Model ---
63
+ class FineGrainedClassifier(nn.Module):
64
+ def __init__(self, num_classes=434): # Updated to 434 classes
65
+ super(FineGrainedClassifier, self).__init__()
66
+ self.image_encoder = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
67
+ self.image_encoder.fc = nn.Identity()
68
+ self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')
69
+ self.classifier = nn.Sequential(
70
+ nn.Linear(2048 + 768, 1024),
71
+ nn.BatchNorm1d(1024),
72
+ nn.ReLU(),
73
+ nn.Dropout(0.3),
74
+ nn.Linear(1024, 512),
75
+ nn.BatchNorm1d(512),
76
+ nn.ReLU(),
77
+ nn.Dropout(0.3),
78
+ nn.Linear(512, num_classes) # Updated to 434 classes
79
+ )
80
+
81
+ def forward(self, image, input_ids, attention_mask):
82
+ image_features = self.image_encoder(image)
83
+ text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
84
+ text_features = text_output.last_hidden_state[:, 0, :]
85
+ combined_features = torch.cat((image_features, text_features), dim=1)
86
+ output = self.classifier(combined_features)
87
+ return output
88
 
89
  # Load the label-to-class mapping from Hugging Face
90
  label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
91
  label_to_class = requests.get(label_map_url).json()
92
 
93
+ # Load the custom model
94
+ model = FineGrainedClassifier(num_classes=len(label_to_class))
95
+ checkpoint_url = f"https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
96
+ checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
97
+
98
+ # Clean up the state dictionary
99
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
100
+ new_state_dict = {}
101
+ for k, v in state_dict.items():
102
+ if k.startswith("module."):
103
+ new_key = k[7:] # Remove "module." prefix
104
+ else:
105
+ new_key = k
106
+
107
+ # Check if the new_key exists in the model's state_dict, only add if it does
108
+ if new_key in model.state_dict():
109
+ new_state_dict[new_key] = v
110
+
111
+ model.load_state_dict(new_state_dict)
112
+
113
+ # Load the tokenizer from Jina
114
  tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
115
 
116
  # Define image preprocessing
 
171
  for class_name, prob in results.items():
172
  print(f"Class: {class_name}, Probability: {prob}")
173
 
 
 
 
 
 
 
 
174
  ```
175
 
176
  # Training Details