Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
from torchvision import models
|
4 |
+
from PIL import Image
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
# Load your trained ResNet-50 model
|
8 |
+
model = models.resnet50(pretrained=False) # Load ResNet-50 architecture
|
9 |
+
model.load_state_dict(torch.load("model.pth")) # Load the trained weights (.pth)
|
10 |
+
model.eval() # Set model to evaluation mode
|
11 |
+
|
12 |
+
# Define the transformation required for the input image
|
13 |
+
transform = transforms.Compose([
|
14 |
+
transforms.Resize(256),
|
15 |
+
transforms.CenterCrop(224),
|
16 |
+
transforms.ToTensor(),
|
17 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
18 |
+
])
|
19 |
+
|
20 |
+
# Define the labels for ImageNet (or your specific dataset labels)
|
21 |
+
# This is typically a list of class labels for classification
|
22 |
+
LABELS = ["class_1", "class_2", "class_3", "class_4", "class_5", # Replace with your classes
|
23 |
+
"class_6", "class_7", "class_8", "class_9", "class_10"]
|
24 |
+
|
25 |
+
# Define the prediction function
|
26 |
+
def predict(image):
|
27 |
+
image = Image.open(image).convert("RGB") # Open the image and convert to RGB
|
28 |
+
image = transform(image).unsqueeze(0) # Apply transformations and add batch dimension
|
29 |
+
|
30 |
+
with torch.no_grad():
|
31 |
+
outputs = model(image) # Get model predictions
|
32 |
+
_, predicted = torch.max(outputs, 1) # Get the class with highest probability
|
33 |
+
return LABELS[predicted.item()] # Return the predicted class label
|
34 |
+
|
35 |
+
# Set up the Gradio interface
|
36 |
+
interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(type="pil"), outputs="text")
|
37 |
+
|
38 |
+
# Launch the interface
|
39 |
+
interface.launch()
|