Maverick98 commited on
Commit
64ffe28
·
verified ·
1 Parent(s): 699133d
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -78,16 +78,14 @@ label_to_class = requests.get(label_map_url).json()
78
  # Load your custom model from Hugging Face
79
  model = FineGrainedClassifier(num_classes=len(label_to_class))
80
  model_checkpoint = "Maverick98/EcommerceClassifier"
81
- model.load_state_dict(torch.hub.load_state_dict_from_url(f"https://huggingface.co/{model_checkpoint}/resolve/main/model_checkpoint.pth", map_location=torch.device('cpu')))
 
 
 
 
82
  # Load the tokenizer from Jina
83
  tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
84
 
85
- # # Define image preprocessing
86
- # transform = transforms.Compose([
87
- # transforms.Resize((224, 224)),
88
- # transforms.ToTensor(),
89
- # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
90
- # ])
91
 
92
  def load_image(image_path_or_url):
93
  """
 
78
  # Load your custom model from Hugging Face
79
  model = FineGrainedClassifier(num_classes=len(label_to_class))
80
  model_checkpoint = "Maverick98/EcommerceClassifier"
81
+ checkpoint_url = f"https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
82
+ checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
83
+ # Extract and load the model state_dict
84
+ model.load_state_dict(checkpoint['model_state_dict'])
85
+
86
  # Load the tokenizer from Jina
87
  tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
88
 
 
 
 
 
 
 
89
 
90
  def load_image(image_path_or_url):
91
  """