File size: 2,845 Bytes
a73e6b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoConfig
import gradio as gr
from PIL import Image
import os
import logging
from safetensors.torch import load_file  # Import safetensors loading function

# Set up logging: INFO level, timestamped messages
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Directory containing the model files. "." is resolved against the process's
# current working directory, so the app must be started from that directory.
model_dir = "."

# Paths to the specific model files
model_path = os.path.join(model_dir, "model.safetensors")
config_path = os.path.join(model_dir, "config.json")
preprocessor_path = os.path.join(model_dir, "preprocessor_config.json")

# Fail fast, before any loading is attempted, if a required file is missing
for path in [model_path, config_path, preprocessor_path]:
    if not os.path.exists(path):
        logging.error(f"File not found: {path}")
        raise FileNotFoundError(f"Required file not found: {path}")
    logging.info(f"Found file: {path}")

# Load the configuration
config = AutoConfig.from_pretrained(config_path)

# Build the label list in ascending class-id order so that labels[i] names
# logit i. Relying on the dict's own iteration order is not safe: it follows
# the order in which config.json lists the entries, which need not be numeric
# (string-sorted JSON keys put "10" before "2"). int() makes the sort numeric
# whether the ids arrive as ints or as strings.
labels = [
    name
    for _, name in sorted(config.id2label.items(), key=lambda item: int(item[0]))
]
logging.info(f"Labels: {labels}")

# Load the feature extractor (image preprocessing settings)
feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)

# Load the weights from the safetensors file, then build the model from the
# config plus that state dict (no model name/path is given to from_pretrained)
state_dict = load_file(model_path)
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=None,
    config=config,
    state_dict=state_dict
)

# Ensure the model is in evaluation mode
model.eval()
logging.info("Model set to evaluation mode")

# Define the prediction function
def predict(image):
    """Classify one image and return a probability for every label.

    Args:
        image: The image to classify, or None. ``image.size`` is read below
            and the Gradio input in this file is declared with
            ``type="pil"``, so a PIL image is expected. None can arrive when
            the form is submitted without an image.

    Returns:
        A dict mapping each entry of the module-level ``labels`` list to the
        softmax probability (as a Python float) found at the same index of
        the model's output for this image.

    Raises:
        gr.Error: If ``image`` is None.
    """
    logging.info("Starting prediction")

    # Reject a missing image up front with a readable message; otherwise the
    # attribute access below fails with an opaque AttributeError.
    if image is None:
        logging.warning("No image provided")
        raise gr.Error("Please provide an image before submitting.")

    logging.info(f"Input image shape: {image.size}")

    # Preprocess the image
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")

    logging.info("Running inference")
    # Inference only: no gradients are needed
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        # logits[0] is the row for the single image passed in; softmax over
        # its only dimension turns the scores into probabilities.
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")

    # Pair each label with the probability at the same index
    result = {label: float(probabilities[i]) for i, label in enumerate(labels)}
    logging.info(f"Prediction result: {result}")

    return result

# Build the Gradio UI: one image input wired through predict() to a label output
logging.info("Setting up Gradio interface")
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=6)
gradio_app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Shoes Height Classifier",
)

# Start the web server only when this file is run as a script
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()