# Sechskomanull / app.py
# Author: Devon12 — commit 4025a0b: "Add Gradio app, requirements, and model" (2.14 kB)
import gradio as gr
import torch
from torchvision import transforms, models
from torch import nn
from PIL import Image
# Build ResNet-50 without pretrained weights and replace its final fully
# connected layer with one sized to this app's label set. All parameters are
# then overwritten by the fine-tuned checkpoint loaded below (load_state_dict
# is strict by default, so every key must be present).
model = models.resnet50(weights=None)
num_classes = 30  # must equal len(class_labels), defined further down
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

# Load the trained weights from best.pt, mapped onto the CPU first so the
# checkpoint also loads on machines without a GPU.
# A failure here is fatal on purpose: continuing would serve predictions from
# a randomly initialised network, i.e. confident but meaningless labels.
try:
    model.load_state_dict(torch.load("best.pt", map_location=torch.device('cpu')))
except Exception as e:
    print(f"Error loading model: {e}")
    raise
print("Model loaded successfully.")

# Run on the GPU when one is available, and put the network in inference mode
# so batch-norm layers use their stored running statistics.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
# Preprocessing applied to every incoming PIL image before inference:
# resize to a fixed 224x224, convert to a float tensor, then normalise each
# channel with the ImageNet mean/std values.
# NOTE(review): assumes the model was trained with this same resize and
# normalisation — confirm against the training script.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Human-readable names for the classifier's 30 output indices: entry i is the
# label for logit i. The list is in alphabetical order — presumably the class
# order used at training time (e.g. torchvision ImageFolder sorts folder
# names); confirm against the training script before reordering.
class_labels = """
    aerosol_cans aluminum_food_cans aluminum_soda_cans cardboard_boxes
    cardboard_packaging clothing coffee_grounds disposable_plastic_cutlery
    eggshells food_waste glass_beverage_bottles glass_cosmetic_containers
    glass_food_jars magazines newspaper office_paper paper_cups
    plastic_cup_lids plastic_detergent_bottles plastic_food_containers
    plastic_shopping_bags plastic_soda_bottles plastic_straws plastic_trash_bags
    plastic_water_bottles shoes steel_food_cans styrofoam_cups
    styrofoam_food_containers tea_bags
""".split()
# Prediction function
def predict_image(image):
    """Classify one image and return its predicted label.

    Args:
        image: PIL image from the Gradio ``Image`` component, or ``None``
            when the input is empty (e.g. after the user clears it).

    Returns:
        The entry of ``class_labels`` with the highest logit, or an empty
        string when no image is provided.
    """
    # With live=True Gradio re-runs this on every input change, including
    # clearing the image, in which case it passes None.
    if image is None:
        return ""
    # The 3-channel Normalize in ``transform`` needs RGB input, so grayscale,
    # RGBA or palette images are converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Add a batch dimension of size 1 and move to the model's device.
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    predicted_index = logits.argmax(dim=1).item()
    return class_labels[predicted_index]
# Gradio UI: one image input delivered to predict_image as a PIL image, and
# the predicted label shown as text. live=True re-runs the prediction on
# every input change, so no submit button is needed.
image_input = gr.Image(type="pil", label="Upload Image or Use Webcam")
interface = gr.Interface(
    fn=predict_image,
    inputs=image_input,
    outputs="text",
    live=True,
)

# Start the web server (runs at import time, as the app entry point).
interface.launch()