Spaces:
Sleeping
Sleeping
File size: 3,955 Bytes
4025a0b 355cc40 4025a0b 8b5f299 4025a0b 8b5f299 4025a0b 8b5f299 4025a0b db5601b 4025a0b 8b5f299 db5601b 8b5f299 db5601b 8b5f299 db5601b 33e388f 473d641 33e388f 3d4cfbe 33e388f 473d641 33e388f 3d4cfbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import gradio as gr
import torch
from torchvision import transforms, models
from torch import nn
from PIL import Image
# Load the model architecture
model = models.resnet50(weights=None)
num_classes = 30
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
# Load the trained model weights
try:
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
print("Model loaded successfully.")
except Exception as e:
print(f"Error loading model: {e}")
# Load your trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
# Define the image transformations
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# Define class labels (German)
class_labels = [
"Spraydosen", "Lebensmitteldosen_aus_Aluminium", "Getränkedosen_aus_Aluminium",
"Pappkartons", "Verpackungen_aus_Pappe", "Kleidung", "Kaffeekränzchen",
"Einweg_Besteck_aus_Kunststoff", "Eierschalen", "Lebensmittelabfälle",
"Getränkeflaschen_aus_Glas", "Kosmetikbehälter_aus_Glas", "Lebensmittelgläser_aus_Glas",
"Zeitschriften", "Zeitungen", "Buropapier", "Pappbecher", "Deckel_aus_Kunststoff",
"Waschmittelbehälter_aus_Kunststoff", "Lebensmittelbehälter_aus_Kunststoff",
"Plastiktüten", "Sodaflaschen_aus_Kunststoff", "Strohhalme_aus_Kunststoff",
"Mülltüten_aus_Kunststoff", "Wasserflaschen_aus_Kunststoff", "Schuhe",
"Lebensmitteldosen_aus_Stahl", "Styroporbecher", "Lebensmittelbehälter_aus_Styropor",
"Teebeutel"
]
# Mapping classes to correct trash bin
class_to_tonne = {
"Spraydosen": "Gelbe Tonne",
"Lebensmitteldosen_aus_Aluminium": "Gelbe Tonne",
"Getränkedosen_aus_Aluminium": "Gelbe Tonne",
"Plastiktüten": "Gelbe Tonne",
"Wasserflaschen_aus_Kunststoff": "Gelbe Tonne",
"Sodaflaschen_aus_Kunststoff": "Gelbe Tonne",
"Lebensmittelbehälter_aus_Kunststoff": "Gelbe Tonne",
"Waschmittelbehälter_aus_Kunststoff": "Gelbe Tonne",
"Deckel_aus_Kunststoff": "Gelbe Tonne",
"Strohhalme_aus_Kunststoff": "Gelbe Tonne",
"Einweg_Besteck_aus_Kunststoff": "Gelbe Tonne",
"Pappkartons": "Papiertonne",
"Verpackungen_aus_Pappe": "Papiertonne",
"Zeitschriften": "Papiertonne",
"Zeitungen": "Papiertonne",
"Buropapier": "Papiertonne",
"Pappbecher": "Papiertonne",
"Kaffeekränzchen": "Biomüll",
"Lebensmittelabfälle": "Biomüll",
"Eierschalen": "Biomüll",
"Teebeutel": "Biomüll",
"Getränkeflaschen_aus_Glas": "Altglas",
"Lebensmittelgläser_aus_Glas": "Altglas",
"Kosmetikbehälter_aus_Glas": "Altglas",
"Kleidung": "Restmüll",
"Schuhe": "Restmüll",
"Styroporbecher": "Restmüll",
"Lebensmittelbehälter_aus_Styropor": "Restmüll",
"Mülltüten_aus_Kunststoff": "Restmüll"
}
# Gradio Block
with gr.Blocks(title="Abfallerkennung mit KI 🗑️") as demo:
gr.Markdown("### Lade ein Bild hoch und erfahre, in welche Tonne der Abfall gehört.")
image_input = gr.Image(type="pil", label="Bild hochladen")
output_label = gr.Textbox(label="Vorhersage")
output_tonne = gr.Textbox(label="Richtiger Abfallbehälter")
button = gr.Button("Analysieren")
def analyze(image):
if image.mode != "RGB":
image = image.convert("RGB")
input_tensor = transform(image).unsqueeze(0).to(model.device)
with torch.no_grad():
outputs = model(input_tensor)
_, predicted = torch.max(outputs, 1)
label = class_labels[predicted.item()]
tonne = class_to_tonne.get(label, "Unbekannt")
label_1 = label.replace("_", " ")
return label_1, tonne
button.click(fn=analyze, inputs=image_input, outputs=[output_label, output_tonne])
demo.launch()
|