# DeathDaDev's picture
# Update app.py
# a063777 verified
import gradio as gr
from transformers import AutoModelForImageClassification, AutoProcessor
import torch
# Hugging Face Hub repository that both the preprocessor and the
# image-classification model below are loaded from.
model_name = "DeathDaDev/Materializer"

# Fetch the preprocessor first, then the classifier (tuple elements are
# evaluated left to right, so the load order matches the original).
processor, model = (
    AutoProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)
# Define the prediction function
def classify_image(image):
# Preprocess the image
inputs = processor(images=image, return_tensors="pt")
# Perform inference
with torch.no_grad():
logits = model(**inputs).logits
# Get the predicted class
predicted_class_idx = logits.argmax(-1).item()
return model.config.id2label[predicted_class_idx]
# Create the Gradio interface
iface = gr.Interface(
fn=classify_image,
inputs=gr.Image(type="pil"),
outputs=gr.Label(num_top_classes=3),
title="Image Classification with Materializer",
description="This model has been trained on texture images that are commonly used for 3d models in an attempt to create an AI model that understands what image 'material' should be used on a specific object. Upload an image to classify it using the Materializer model."
)
# Launch the interface
iface.launch()