Spaces:
Sleeping
Sleeping
import torch | |
from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoConfig | |
import gradio as gr | |
from PIL import Image | |
import os | |
import logging | |
from safetensors.torch import load_file # Import safetensors loading function | |
# Configure root logging once: timestamp, severity, then the message.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Every model artifact is looked up relative to the current directory.
model_dir = "."
model_path = os.path.join(model_dir, "model.safetensors")
config_path = os.path.join(model_dir, "config.json")
preprocessor_path = os.path.join(model_dir, "preprocessor_config.json")

# Walk the required artifacts in order: log each one that is present and
# abort with FileNotFoundError at the first one that is missing.
for path in (model_path, config_path, preprocessor_path):
    if os.path.exists(path):
        logging.info(f"Found file: {path}")
        continue
    logging.error(f"File not found: {path}")
    raise FileNotFoundError(f"Required file not found: {path}")
# Load the model configuration from the local config.json.
config = AutoConfig.from_pretrained(config_path)

# Build the label list in class-index order so that labels[i] names the class
# behind logit i. Relying on the dict's insertion order would mislabel the
# predictions if config.json lists the ids out of order; sorting by the
# numeric id makes the mapping independent of the file's ordering.
labels = [config.id2label[class_id] for class_id in sorted(config.id2label, key=int)]
logging.info(f"Labels: {labels}")

# Load the image preprocessor settings from preprocessor_config.json.
feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)

# Read the weights from the safetensors file and hand them to from_pretrained
# together with the config; no model name/path is given because both the
# configuration and the state dict are supplied explicitly.
state_dict = load_file(model_path)
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=None,
    config=config,
    state_dict=state_dict,
)

# Switch to evaluation mode before serving predictions.
model.eval()
logging.info("Model set to evaluation mode")
# Define the prediction function
def predict(image):
    """Classify one image and return a label -> probability mapping.

    Args:
        image: The image to classify, as handed over by the Gradio image
            input (configured with ``type="pil"``). May be ``None`` when the
            form is submitted without an image.

    Returns:
        dict: Each entry of the module-level ``labels`` list mapped to the
        softmax probability (a Python float) at the same index.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Without this guard the attribute access below fails with an opaque
        # AttributeError; gr.Error surfaces a readable message in the UI.
        logging.warning("Prediction requested without an image")
        raise gr.Error("Please provide an image to classify.")

    logging.info("Starting prediction")
    logging.info(f"Input image shape: {image.size}")

    # Preprocess the image into the tensor dict the model consumes.
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")

    logging.info("Running inference")
    with torch.no_grad():  # inference only, no gradient tracking needed
        outputs = model(**inputs)
        logits = outputs.logits
        # logits[0] selects the single image in the batch; softmax over its
        # class dimension turns the scores into probabilities.
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")

    # Convert once to plain Python floats, then pair each label with the
    # probability at the same index.
    scores = probabilities.tolist()
    result = {label: scores[i] for i, label in enumerate(labels)}
    logging.info(f"Prediction result: {result}")
    return result
# Assemble the Gradio UI: an image input delivered to `predict` as a PIL
# image, and a label output created with num_top_classes=6.
logging.info("Setting up Gradio interface")
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=6)
gradio_app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Shirts Buttons Classifier",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()