EdBoy2202 commited on
Commit
ad0c1dd
·
verified ·
1 Parent(s): b24890a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -8
app.py CHANGED
@@ -9,7 +9,8 @@ import matplotlib.pyplot as plt
9
  import numpy as np
10
  from sklearn.preprocessing import LabelEncoder
11
  from huggingface_hub import hf_hub_download
12
- from transformers import pipeline
 
13
  from sklearn.feature_extraction.text import TfidfVectorizer
14
  from sklearn.metrics.pairwise import cosine_similarity
15
  import re
@@ -31,14 +32,28 @@ def load_image(image_file):
31
 
32
  def classify_image(image):
33
  try:
34
- # Create a pipeline for image classification
35
- classifier = pipeline('image-classification', model="dima806/car_models_image_detection", device=-1) # Use -1 for CPU, or 0 for GPU if available
36
-
37
- # Classify the image
38
- results = classifier(image)
 
 
 
 
 
 
 
 
 
 
39
 
40
- # Return top 5 predictions
41
- return results[:5]
 
 
 
 
42
 
43
  except Exception as e:
44
  st.error(f"Classification error: {e}")
 
9
  import numpy as np
10
  from sklearn.preprocessing import LabelEncoder
11
  from huggingface_hub import hf_hub_download
12
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
13
+ import torch
14
  from sklearn.feature_extraction.text import TfidfVectorizer
15
  from sklearn.metrics.pairwise import cosine_similarity
16
  import re
 
32
 
33
  def classify_image(image):
34
  try:
35
+ # Load the model and feature extractor
36
+ model_name = "dima806/car_models_image_detection"
37
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
38
+ model = AutoModelForImageClassification.from_pretrained(model_name)
39
+
40
+ # Preprocess the image
41
+ inputs = feature_extractor(images=image, return_tensors="pt")
42
+
43
+ # Perform inference
44
+ with torch.no_grad():
45
+ outputs = model(**inputs)
46
+
47
+ # Get the predicted class
48
+ logits = outputs.logits
49
+ predicted_class_idx = logits.argmax(-1).item()
50
 
51
+ # Get the class label and score
52
+ predicted_class_label = model.config.id2label[predicted_class_idx]
53
+ score = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
54
+
55
+ # Return the top prediction
56
+ return [{'label': predicted_class_label, 'score': score}]
57
 
58
  except Exception as e:
59
  st.error(f"Classification error: {e}")