import gradio as gr import torch from torchvision import transforms, models from torch import nn from PIL import Image # Load the model architecture model = models.resnet50(weights=None) num_classes = 30 num_features = model.fc.in_features model.fc = nn.Linear(num_features, num_classes) # Load the trained model weights try: model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu'))) print("Model loaded successfully.") except Exception as e: print(f"Error loading model: {e}") # Load your trained model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) model.eval() # Define the image transformations transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Define class labels (German) class_labels = [ "Spraydosen", "Lebensmitteldosen_aus_Aluminium", "Getränkedosen_aus_Aluminium", "Pappkartons", "Verpackungen_aus_Pappe", "Kleidung", "Kaffeekränzchen", "Einweg_Besteck_aus_Kunststoff", "Eierschalen", "Lebensmittelabfälle", "Getränkeflaschen_aus_Glas", "Kosmetikbehälter_aus_Glas", "Lebensmittelgläser_aus_Glas", "Zeitschriften", "Zeitungen", "Buropapier", "Pappbecher", "Deckel_aus_Kunststoff", "Waschmittelbehälter_aus_Kunststoff", "Lebensmittelbehälter_aus_Kunststoff", "Plastiktüten", "Sodaflaschen_aus_Kunststoff", "Strohhalme_aus_Kunststoff", "Mülltüten_aus_Kunststoff", "Wasserflaschen_aus_Kunststoff", "Schuhe", "Lebensmitteldosen_aus_Stahl", "Styroporbecher", "Lebensmittelbehälter_aus_Styropor", "Teebeutel" ] # Mapping classes to correct trash bin class_to_tonne = { "Spraydosen": "Gelbe Tonne", "Lebensmitteldosen_aus_Aluminium": "Gelbe Tonne", "Getränkedosen_aus_Aluminium": "Gelbe Tonne", "Plastiktüten": "Gelbe Tonne", "Wasserflaschen_aus_Kunststoff": "Gelbe Tonne", "Sodaflaschen_aus_Kunststoff": "Gelbe Tonne", "Lebensmittelbehälter_aus_Kunststoff": "Gelbe Tonne", "Waschmittelbehälter_aus_Kunststoff": "Gelbe Tonne", "Deckel_aus_Kunststoff": "Gelbe Tonne", "Strohhalme_aus_Kunststoff": "Gelbe Tonne", "Einweg_Besteck_aus_Kunststoff": "Gelbe Tonne", "Pappkartons": "Papiertonne", "Verpackungen_aus_Pappe": "Papiertonne", "Zeitschriften": "Papiertonne", "Zeitungen": "Papiertonne", "Buropapier": "Papiertonne", "Pappbecher": "Papiertonne", "Kaffeekränzchen": "Biomüll", "Lebensmittelabfälle": "Biomüll", "Eierschalen": "Biomüll", "Teebeutel": "Biomüll", "Getränkeflaschen_aus_Glas": "Altglas", "Lebensmittelgläser_aus_Glas": "Altglas", "Kosmetikbehälter_aus_Glas": "Altglas", "Kleidung": "Restmüll", "Schuhe": "Restmüll", "Styroporbecher": "Restmüll", "Lebensmittelbehälter_aus_Styropor": "Restmüll", "Mülltüten_aus_Kunststoff": "Restmüll" } # Gradio Blocks UI with gr.Blocks(title="Abfallerkennung mit KI 🗑️") as demo: gr.Markdown("### Lade ein Bild hoch und erfahre, in welche Tonne der Abfall gehört.") image_input = gr.Image(type="pil", label="Bild hochladen") output_label = gr.Textbox(label="Vorhersage") output_tonne = gr.Textbox(label="Richtiger Abfallbehälter") button = gr.Button("Analysieren") def analyze(image): if image.mode != "RGB": image = image.convert("RGB") input_tensor = transform(image).unsqueeze(0).to(device) with torch.no_grad(): outputs = model(input_tensor) _, predicted = torch.max(outputs, 1) label = class_labels[predicted.item()] tonne = class_to_tonne.get(label, "Unbekannt") return label.replace("_", " ").capitalize(), tonne button.click(fn=analyze, inputs=image_input, outputs=[output_label, output_tonne]) demo.launch(share=True)