Sechskomanull / app.py
Devon12's picture
Add Gradio app
60ca92a
import gradio as gr
import torch
from torchvision import transforms, models
from torch import nn
from PIL import Image
# Load the model architecture
model = models.resnet50(weights=None)
num_classes = 30
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
# Load the trained model weights
try:
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
print("Model loaded successfully.")
except Exception as e:
print(f"Error loading model: {e}")
# Load your trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
# Define the image transformations
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# Define class labels (German)
class_labels = [
"Spraydosen", "Lebensmitteldosen_aus_Aluminium", "Getränkedosen_aus_Aluminium",
"Pappkartons", "Verpackungen_aus_Pappe", "Kleidung", "Kaffeekränzchen",
"Einweg_Besteck_aus_Kunststoff", "Eierschalen", "Lebensmittelabfälle",
"Getränkeflaschen_aus_Glas", "Kosmetikbehälter_aus_Glas", "Lebensmittelgläser_aus_Glas",
"Zeitschriften", "Zeitungen", "Buropapier", "Pappbecher", "Deckel_aus_Kunststoff",
"Waschmittelbehälter_aus_Kunststoff", "Lebensmittelbehälter_aus_Kunststoff",
"Plastiktüten", "Sodaflaschen_aus_Kunststoff", "Strohhalme_aus_Kunststoff",
"Mülltüten_aus_Kunststoff", "Wasserflaschen_aus_Kunststoff", "Schuhe",
"Lebensmitteldosen_aus_Stahl", "Styroporbecher", "Lebensmittelbehälter_aus_Styropor",
"Teebeutel"
]
# Mapping classes to correct trash bin
class_to_tonne = {
"Spraydosen": "Gelbe Tonne",
"Lebensmitteldosen_aus_Aluminium": "Gelbe Tonne",
"Getränkedosen_aus_Aluminium": "Gelbe Tonne",
"Plastiktüten": "Gelbe Tonne",
"Wasserflaschen_aus_Kunststoff": "Gelbe Tonne",
"Sodaflaschen_aus_Kunststoff": "Gelbe Tonne",
"Lebensmittelbehälter_aus_Kunststoff": "Gelbe Tonne",
"Waschmittelbehälter_aus_Kunststoff": "Gelbe Tonne",
"Deckel_aus_Kunststoff": "Gelbe Tonne",
"Strohhalme_aus_Kunststoff": "Gelbe Tonne",
"Einweg_Besteck_aus_Kunststoff": "Gelbe Tonne",
"Pappkartons": "Papiertonne",
"Verpackungen_aus_Pappe": "Papiertonne",
"Zeitschriften": "Papiertonne",
"Zeitungen": "Papiertonne",
"Buropapier": "Papiertonne",
"Pappbecher": "Papiertonne",
"Kaffeekränzchen": "Biomüll",
"Lebensmittelabfälle": "Biomüll",
"Eierschalen": "Biomüll",
"Teebeutel": "Biomüll",
"Getränkeflaschen_aus_Glas": "Altglas",
"Lebensmittelgläser_aus_Glas": "Altglas",
"Kosmetikbehälter_aus_Glas": "Altglas",
"Kleidung": "Restmüll",
"Schuhe": "Restmüll",
"Styroporbecher": "Restmüll",
"Lebensmittelbehälter_aus_Styropor": "Restmüll",
"Mülltüten_aus_Kunststoff": "Restmüll"
}
# Gradio Blocks UI
with gr.Blocks(title="Abfallerkennung mit KI 🗑️") as demo:
gr.Markdown("### Lade ein Bild hoch und erfahre, in welche Tonne der Abfall gehört.")
image_input = gr.Image(type="pil", label="Bild hochladen")
output_label = gr.Textbox(label="Vorhersage")
output_tonne = gr.Textbox(label="Richtiger Abfallbehälter")
button = gr.Button("Analysieren")
def analyze(image):
if image.mode != "RGB":
image = image.convert("RGB")
input_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(input_tensor)
_, predicted = torch.max(outputs, 1)
label = class_labels[predicted.item()]
tonne = class_to_tonne.get(label, "Unbekannt")
return label.replace("_", " ").capitalize(), tonne
button.click(fn=analyze, inputs=image_input, outputs=[output_label, output_tonne])
demo.launch(share=True)