# Shirts_Buttons / app.py
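"""
Gradio app serving a ViT image classifier for shirt buttons.

The script loads the model weights (model.safetensors), model configuration
(config.json), and preprocessor settings (preprocessor_config.json) from the
current directory, then exposes a single-image prediction function through a
Gradio Interface. Run with `python app.py` to launch the UI locally.
"""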
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoConfig
import gradio as gr
from PIL import Image
import os
import logging
from safetensors.torch import load_file # Import safetensors loading function
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Define the directory containing the model files
model_dir = "." # Use current directory
# Define paths to the specific model files
model_path = os.path.join(model_dir, "model.safetensors")
config_path = os.path.join(model_dir, "config.json")
preprocessor_path = os.path.join(model_dir, "preprocessor_config.json")
# Check if all required files exist
for path in [model_path, config_path, preprocessor_path]:
    if not os.path.exists(path):
        logging.error(f"File not found: {path}")
        raise FileNotFoundError(f"Required file not found: {path}")
    else:
        logging.info(f"Found file: {path}")
# Load the configuration
config = AutoConfig.from_pretrained(config_path)
# Ensure the labels are consistent with the model's config
labels = list(config.id2label.values())
logging.info(f"Labels: {labels}")
# Load the feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)
# Load the model using the safetensors file
state_dict = load_file(model_path) # Use safetensors to load the model weights
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=None,
    config=config,
    state_dict=state_dict
)
# Ensure the model is in evaluation mode
model.eval()
logging.info("Model set to evaluation mode")
# Define the prediction function
def predict(image):
    logging.info("Starting prediction")
    logging.info(f"Input image size: {image.size}")

    # Preprocess the image into the tensor format the ViT model expects
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")

    # Run inference without tracking gradients
    logging.info("Running inference")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")

    # Map each label to its predicted probability for the Gradio Label output
    result = {labels[i]: float(probabilities[i]) for i in range(len(labels))}
    logging.info(f"Prediction result: {result}")
    return result
# Set up the Gradio Interface
logging.info("Setting up Gradio interface")
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=6),
    title="Shirts Buttons Classifier"
)
# Launch the app
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()