nitinnyadavvv commited on
Commit
bf4aba3
·
verified ·
1 Parent(s): 0fa5667

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import io
6
+
7
+ # Load model and feature extractor
8
+ def load_model():
9
+ processor = AutoImageProcessor.from_pretrained("therealcyberlord/stanford-car-vit-patch16")
10
+ model = AutoModelForImageClassification.from_pretrained("therealcyberlord/stanford-car-vit-patch16")
11
+ return processor, model
12
+
13
+ processor, model = load_model()
14
+
15
+ # Function to classify image
16
+ def classify_image(image):
17
+ # Convert image if necessary
18
+ if not isinstance(image, Image.Image):
19
+ image = Image.open(io.BytesIO(image)).convert("RGB")
20
+
21
+ inputs = processor(images=image, return_tensors="pt")
22
+ with torch.no_grad():
23
+ outputs = model(**inputs)
24
+ logits = outputs.logits
25
+ predicted_class_idx = logits.argmax(-1).item()
26
+ labels = model.config.id2label
27
+ predicted_class = labels[predicted_class_idx]
28
+ return predicted_class
29
+
30
+ # Define Gradio Interface
31
+ app = gr.Interface(
32
+ fn=classify_image,
33
+ inputs=gr.Image(type="pil"),
34
+ outputs="text",
35
+ title="Car Classification",
36
+ description="Upload a car image to classify its model."
37
+ )
38
+
39
+ # Launch the app
40
+ app.launch()