File size: 3,955 Bytes
4025a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
355cc40
4025a0b
 
 
 
 
 
 
 
 
8b5f299
4025a0b
 
 
8b5f299
 
4025a0b
 
8b5f299
4025a0b
db5601b
 
 
 
 
 
 
 
 
 
4025a0b
 
8b5f299
db5601b
 
8b5f299
db5601b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b5f299
db5601b
 
 
 
 
 
 
60ca92a
33e388f
 
60ca92a
33e388f
 
 
 
473d641
33e388f
 
 
60ca92a
33e388f
 
 
 
 
60ca92a
473d641
33e388f
 
19cb8a0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import torch
from torchvision import transforms, models
from torch import nn
from PIL import Image

# Load the model architecture
model = models.resnet50(weights=None)
num_classes = 30
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

# Load the trained model weights
try:
    model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")

# Load your trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Define the image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Define class labels (German)
class_labels = [
    "Spraydosen", "Lebensmitteldosen_aus_Aluminium", "Getränkedosen_aus_Aluminium",
    "Pappkartons", "Verpackungen_aus_Pappe", "Kleidung", "Kaffeekränzchen",
    "Einweg_Besteck_aus_Kunststoff", "Eierschalen", "Lebensmittelabfälle",
    "Getränkeflaschen_aus_Glas", "Kosmetikbehälter_aus_Glas", "Lebensmittelgläser_aus_Glas",
    "Zeitschriften", "Zeitungen", "Buropapier", "Pappbecher", "Deckel_aus_Kunststoff",
    "Waschmittelbehälter_aus_Kunststoff", "Lebensmittelbehälter_aus_Kunststoff",
    "Plastiktüten", "Sodaflaschen_aus_Kunststoff", "Strohhalme_aus_Kunststoff",
    "Mülltüten_aus_Kunststoff", "Wasserflaschen_aus_Kunststoff", "Schuhe",
    "Lebensmitteldosen_aus_Stahl", "Styroporbecher", "Lebensmittelbehälter_aus_Styropor",
    "Teebeutel"
]

# Mapping classes to correct trash bin
class_to_tonne = {
    "Spraydosen": "Gelbe Tonne",
    "Lebensmitteldosen_aus_Aluminium": "Gelbe Tonne",
    "Getränkedosen_aus_Aluminium": "Gelbe Tonne",
    "Plastiktüten": "Gelbe Tonne",
    "Wasserflaschen_aus_Kunststoff": "Gelbe Tonne",
    "Sodaflaschen_aus_Kunststoff": "Gelbe Tonne",
    "Lebensmittelbehälter_aus_Kunststoff": "Gelbe Tonne",
    "Waschmittelbehälter_aus_Kunststoff": "Gelbe Tonne",
    "Deckel_aus_Kunststoff": "Gelbe Tonne",
    "Strohhalme_aus_Kunststoff": "Gelbe Tonne",
    "Einweg_Besteck_aus_Kunststoff": "Gelbe Tonne",
    "Pappkartons": "Papiertonne",
    "Verpackungen_aus_Pappe": "Papiertonne",
    "Zeitschriften": "Papiertonne",
    "Zeitungen": "Papiertonne",
    "Buropapier": "Papiertonne",
    "Pappbecher": "Papiertonne",
    "Kaffeekränzchen": "Biomüll",
    "Lebensmittelabfälle": "Biomüll",
    "Eierschalen": "Biomüll",
    "Teebeutel": "Biomüll",
    "Getränkeflaschen_aus_Glas": "Altglas",
    "Lebensmittelgläser_aus_Glas": "Altglas",
    "Kosmetikbehälter_aus_Glas": "Altglas",
    "Kleidung": "Restmüll",
    "Schuhe": "Restmüll",
    "Styroporbecher": "Restmüll",
    "Lebensmittelbehälter_aus_Styropor": "Restmüll",
    "Mülltüten_aus_Kunststoff": "Restmüll"
}

# Gradio Blocks UI
with gr.Blocks(title="Abfallerkennung mit KI 🗑️") as demo:
    gr.Markdown("### Lade ein Bild hoch und erfahre, in welche Tonne der Abfall gehört.")
    
    image_input = gr.Image(type="pil", label="Bild hochladen")
    output_label = gr.Textbox(label="Vorhersage")
    output_tonne = gr.Textbox(label="Richtiger Abfallbehälter")
    button = gr.Button("Analysieren")

    def analyze(image):
        if image.mode != "RGB":
            image = image.convert("RGB")
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_tensor)
            _, predicted = torch.max(outputs, 1)
            label = class_labels[predicted.item()]
            tonne = class_to_tonne.get(label, "Unbekannt")
        return label.replace("_", " ").capitalize(), tonne

    button.click(fn=analyze, inputs=image_input, outputs=[output_label, output_tonne])

demo.launch(share=True)