Spaces:
Sleeping
Sleeping
File size: 3,955 Bytes
4025a0b 355cc40 4025a0b 8b5f299 4025a0b 8b5f299 4025a0b 8b5f299 4025a0b db5601b 4025a0b 8b5f299 db5601b 8b5f299 db5601b 8b5f299 db5601b 60ca92a 33e388f 60ca92a 33e388f 473d641 33e388f 60ca92a 33e388f 60ca92a 473d641 33e388f 19cb8a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import gradio as gr
import torch
from torchvision import transforms, models
from torch import nn
from PIL import Image
# Load the model architecture
model = models.resnet50(weights=None)
num_classes = 30
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
# Load the trained model weights
try:
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
print("Model loaded successfully.")
except Exception as e:
print(f"Error loading model: {e}")
# Load your trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
# Define the image transformations
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# Define class labels (German)
class_labels = [
"Spraydosen", "Lebensmitteldosen_aus_Aluminium", "Getränkedosen_aus_Aluminium",
"Pappkartons", "Verpackungen_aus_Pappe", "Kleidung", "Kaffeekränzchen",
"Einweg_Besteck_aus_Kunststoff", "Eierschalen", "Lebensmittelabfälle",
"Getränkeflaschen_aus_Glas", "Kosmetikbehälter_aus_Glas", "Lebensmittelgläser_aus_Glas",
"Zeitschriften", "Zeitungen", "Buropapier", "Pappbecher", "Deckel_aus_Kunststoff",
"Waschmittelbehälter_aus_Kunststoff", "Lebensmittelbehälter_aus_Kunststoff",
"Plastiktüten", "Sodaflaschen_aus_Kunststoff", "Strohhalme_aus_Kunststoff",
"Mülltüten_aus_Kunststoff", "Wasserflaschen_aus_Kunststoff", "Schuhe",
"Lebensmitteldosen_aus_Stahl", "Styroporbecher", "Lebensmittelbehälter_aus_Styropor",
"Teebeutel"
]
# Mapping classes to correct trash bin
class_to_tonne = {
"Spraydosen": "Gelbe Tonne",
"Lebensmitteldosen_aus_Aluminium": "Gelbe Tonne",
"Getränkedosen_aus_Aluminium": "Gelbe Tonne",
"Plastiktüten": "Gelbe Tonne",
"Wasserflaschen_aus_Kunststoff": "Gelbe Tonne",
"Sodaflaschen_aus_Kunststoff": "Gelbe Tonne",
"Lebensmittelbehälter_aus_Kunststoff": "Gelbe Tonne",
"Waschmittelbehälter_aus_Kunststoff": "Gelbe Tonne",
"Deckel_aus_Kunststoff": "Gelbe Tonne",
"Strohhalme_aus_Kunststoff": "Gelbe Tonne",
"Einweg_Besteck_aus_Kunststoff": "Gelbe Tonne",
"Pappkartons": "Papiertonne",
"Verpackungen_aus_Pappe": "Papiertonne",
"Zeitschriften": "Papiertonne",
"Zeitungen": "Papiertonne",
"Buropapier": "Papiertonne",
"Pappbecher": "Papiertonne",
"Kaffeekränzchen": "Biomüll",
"Lebensmittelabfälle": "Biomüll",
"Eierschalen": "Biomüll",
"Teebeutel": "Biomüll",
"Getränkeflaschen_aus_Glas": "Altglas",
"Lebensmittelgläser_aus_Glas": "Altglas",
"Kosmetikbehälter_aus_Glas": "Altglas",
"Kleidung": "Restmüll",
"Schuhe": "Restmüll",
"Styroporbecher": "Restmüll",
"Lebensmittelbehälter_aus_Styropor": "Restmüll",
"Mülltüten_aus_Kunststoff": "Restmüll"
}
# Gradio Blocks UI
with gr.Blocks(title="Abfallerkennung mit KI 🗑️") as demo:
gr.Markdown("### Lade ein Bild hoch und erfahre, in welche Tonne der Abfall gehört.")
image_input = gr.Image(type="pil", label="Bild hochladen")
output_label = gr.Textbox(label="Vorhersage")
output_tonne = gr.Textbox(label="Richtiger Abfallbehälter")
button = gr.Button("Analysieren")
def analyze(image):
if image.mode != "RGB":
image = image.convert("RGB")
input_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(input_tensor)
_, predicted = torch.max(outputs, 1)
label = class_labels[predicted.item()]
tonne = class_to_tonne.get(label, "Unbekannt")
return label.replace("_", " ").capitalize(), tonne
button.click(fn=analyze, inputs=image_input, outputs=[output_label, output_tonne])
demo.launch(share=True)
|