Devon12 commited on
Commit
8b5f299
·
verified ·
1 Parent(s): 2d355ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -20
app.py CHANGED
@@ -22,14 +22,15 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  model = model.to(device)
23
  model.eval()
24
 
25
- # Define the image transformations (adjust as needed for your model)
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
29
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
30
  ])
31
 
32
- # Define class labels
33
  class_labels = [
34
  "Spraydosen", "Lebensmitteldosen_aus_Aluminium", "Getränkedosen_aus_Aluminium",
35
  "Pappkartons", "Verpackungen_aus_Pappe", "Kleidung", "Kaffeekränzchen",
@@ -43,9 +44,10 @@ class_labels = [
43
  "Teebeutel"
44
  ]
45
 
 
46
  class_to_tonne = {
47
  "Spraydosen": "Gelbe Tonne",
48
- "Lebensmitteldosen_aus_Aluminiumm": "Gelbe Tonne",
49
  "Getränkedosen_aus_Aluminium": "Gelbe Tonne",
50
  "Plastiktüten": "Gelbe Tonne",
51
  "Wasserflaschen_aus_Kunststoff": "Gelbe Tonne",
@@ -67,7 +69,7 @@ class_to_tonne = {
67
  "Teebeutel": "Biomüll",
68
  "Getränkeflaschen_aus_Glas": "Altglas",
69
  "Lebensmittelgläser_aus_Glas": "Altglas",
70
- "Losmetikbehälter_aus_Glas": "Altglas",
71
  "Kleidung": "Restmüll",
72
  "Schuhe": "Restmüll",
73
  "Styroporbecher": "Restmüll",
@@ -75,22 +77,10 @@ class_to_tonne = {
75
  "Mülltüten_aus_Kunststoff": "Restmüll"
76
  }
77
 
78
-
79
- # Prediction function
80
- def predict_image(image):
81
- if image.mode != "RGB":
82
- image = image.convert("RGB")
83
- input_tensor = transform(image).unsqueeze(0).to(model.device)
84
- with torch.no_grad():
85
- outputs = model(input_tensor)
86
- _, predicted = torch.max(outputs, 1)
87
- label = class_labels[predicted.item()]
88
- tonne = class_to_tonne.get(label, "Unbekannt")
89
- return label.replace("_", " ").capitalize(), tonne
90
-
91
- # Gradio interface setup
92
  with gr.Blocks(title="Abfallerkennung mit KI 🗑️") as demo:
93
  gr.Markdown("### Lade ein Bild hoch und erfahre, in welche Tonne der Abfall gehört.")
 
94
  image_input = gr.Image(type="pil", label="Bild hochladen")
95
  output_label = gr.Textbox(label="Vorhersage")
96
  output_tonne = gr.Textbox(label="Richtiger Abfallbehälter")
@@ -99,7 +89,7 @@ with gr.Blocks(title="Abfallerkennung mit KI 🗑️") as demo:
99
  def analyze(image):
100
  if image.mode != "RGB":
101
  image = image.convert("RGB")
102
- input_tensor = transform(image).unsqueeze(0).to(model.device)
103
  with torch.no_grad():
104
  outputs = model(input_tensor)
105
  _, predicted = torch.max(outputs, 1)
 
22
  model = model.to(device)
23
  model.eval()
24
 
25
+ # Define the image transformations
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
30
+ std=[0.229, 0.224, 0.225])
31
  ])
32
 
33
+ # Define class labels (German)
34
  class_labels = [
35
  "Spraydosen", "Lebensmitteldosen_aus_Aluminium", "Getränkedosen_aus_Aluminium",
36
  "Pappkartons", "Verpackungen_aus_Pappe", "Kleidung", "Kaffeekränzchen",
 
44
  "Teebeutel"
45
  ]
46
 
47
+ # Mapping classes to correct trash bin
48
  class_to_tonne = {
49
  "Spraydosen": "Gelbe Tonne",
50
+ "Lebensmitteldosen_aus_Aluminium": "Gelbe Tonne",
51
  "Getränkedosen_aus_Aluminium": "Gelbe Tonne",
52
  "Plastiktüten": "Gelbe Tonne",
53
  "Wasserflaschen_aus_Kunststoff": "Gelbe Tonne",
 
69
  "Teebeutel": "Biomüll",
70
  "Getränkeflaschen_aus_Glas": "Altglas",
71
  "Lebensmittelgläser_aus_Glas": "Altglas",
72
+ "Kosmetikbehälter_aus_Glas": "Altglas",
73
  "Kleidung": "Restmüll",
74
  "Schuhe": "Restmüll",
75
  "Styroporbecher": "Restmüll",
 
77
  "Mülltüten_aus_Kunststoff": "Restmüll"
78
  }
79
 
80
+ # Gradio Blocks UI
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  with gr.Blocks(title="Abfallerkennung mit KI 🗑️") as demo:
82
  gr.Markdown("### Lade ein Bild hoch und erfahre, in welche Tonne der Abfall gehört.")
83
+
84
  image_input = gr.Image(type="pil", label="Bild hochladen")
85
  output_label = gr.Textbox(label="Vorhersage")
86
  output_tonne = gr.Textbox(label="Richtiger Abfallbehälter")
 
89
  def analyze(image):
90
  if image.mode != "RGB":
91
  image = image.convert("RGB")
92
+ input_tensor = transform(image).unsqueeze(0).to(device)
93
  with torch.no_grad():
94
  outputs = model(input_tensor)
95
  _, predicted = torch.max(outputs, 1)