EdBoy2202 commited on
Commit
57babb1
·
verified ·
1 Parent(s): fbddb10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -35
app.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
10
  from sklearn.preprocessing import LabelEncoder
11
  from huggingface_hub import hf_hub_download
12
  import base64
 
13
 
14
  # Dataset loading function with caching
15
  @st.cache_data
@@ -33,37 +34,18 @@ def load_image(image_file):
33
 
34
 
35
  def classify_image(image):
36
- # Convert PIL Image to bytes
37
- img_byte_arr = BytesIO()
38
- image.save(img_byte_arr, format='PNG')
39
- img_byte_arr = img_byte_arr.getvalue()
40
-
41
- # Encode image to base64
42
- encoded_image = base64.b64encode(img_byte_arr).decode('ascii')
43
-
44
- headers = {
45
- "Authorization": f"Bearer {HUGGINGFACE_API_KEY}",
46
- "Content-Type": "application/json"
47
- }
48
 
49
- payload = {
50
- "inputs": encoded_image
51
- }
52
-
53
- try:
54
- response = requests.post(
55
- 'https://api-inference.huggingface.co/models/dima806/car_models_image_detection',
56
- headers=headers,
57
- json=payload
58
- )
59
- response.raise_for_status() # Raises an HTTPError for bad responses
60
- return response.json()
61
- except requests.exceptions.RequestException as e:
62
- st.error(f"Image classification failed: {e}")
63
- if response.text:
64
- st.error(f"API Response: {response.text}")
65
- return None
66
-
67
  def find_closest_match(df, brand, model):
68
  match = df[(df['Make'].str.contains(brand, case=False)) & (df['Model'].str.contains(model, case=False))]
69
  if not match.empty:
@@ -139,11 +121,7 @@ if camera_image is not None:
139
  car_info = classify_image(image)
140
 
141
  if car_info:
142
- brand = car_info.get('brand', None) # Adjust according to the response structure
143
- model_name = car_info.get('model', None)
144
-
145
- if brand and model_name:
146
- st.write(f"Identified Car: {brand} {model_name}")
147
 
148
  # Find the closest match in the CSV
149
  match = find_closest_match(df, brand, model_name)
 
10
  from sklearn.preprocessing import LabelEncoder
11
  from huggingface_hub import hf_hub_download
12
  import base64
13
+ from transformers import ViTImageProcessor, ViTForImageClassification
14
 
15
  # Dataset loading function with caching
16
  @st.cache_data
 
34
 
35
 
36
  def classify_image(image):
37
+ processor = ViTImageProcessor.from_pretrained("dima806/car_models_image_detection")
38
+ model = ViTForImageClassification.from_pretrained("dima806/car_models_image_detection")
 
 
 
 
 
 
 
 
 
 
39
 
40
+ inputs = processor(images=image, return_tensors="pt")
41
+
42
+ with torch.no_grad():
43
+ outputs = model(**inputs)
44
+
45
+ logits = outputs.logits
46
+ predicted_class_idx = logits.argmax(-1).item()
47
+
48
+ return model.config.id2label[predicted_class_idx]
 
 
 
 
 
 
 
 
 
49
  def find_closest_match(df, brand, model):
50
  match = df[(df['Make'].str.contains(brand, case=False)) & (df['Model'].str.contains(model, case=False))]
51
  if not match.empty:
 
121
  car_info = classify_image(image)
122
 
123
  if car_info:
124
+ st.write(f"Identified Car: {car_info}")
 
 
 
 
125
 
126
  # Find the closest match in the CSV
127
  match = find_closest_match(df, brand, model_name)