"""Streamlit app that classifies an uploaded flower image with a ViT model."""

import streamlit as st
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch
import matplotlib.pyplot as plt

# Define the repository ID
repo_id = "Hammad712/5-Flower-Types-Classification-VIT-Model"


@st.cache_resource
def _load_model_and_extractor(repo):
    """Load the classifier and its feature extractor once per server process.

    Streamlit re-executes this script on every user interaction; caching the
    loaded objects avoids re-reading the model weights on each rerun.

    Args:
        repo: Hugging Face Hub repository ID to load from.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    loaded_model = ViTForImageClassification.from_pretrained(repo)
    loaded_extractor = ViTFeatureExtractor.from_pretrained(repo)
    return loaded_model, loaded_extractor


# Load the model and feature extractor (served from the cache after first run)
model, feature_extractor = _load_model_and_extractor(repo_id)

# Map class indices produced by the model to display names
class_names = {0: 'Lilly', 1: 'Lotus', 2: 'Orchid', 3: 'Sunflower', 4: 'Tulip'}


def predict(image):
    """Classify a single image.

    Args:
        image: The image to classify (the app passes an RGB ``PIL.Image``).

    Returns:
        A ``(probabilities, predicted_class_name)`` tuple, where
        ``probabilities`` is a list with one softmax probability per class
        index and ``predicted_class_name`` is the ``class_names`` entry for
        the highest-scoring class.

    Raises:
        KeyError: If the winning class index has no entry in ``class_names``.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Index the single batch row explicitly instead of squeezing every
    # size-1 dimension; the result is always a per-class list.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class_name = class_names[predicted_class_idx]
    return probabilities, predicted_class_name


# Streamlit app
st.title("Flower Type Classification")
st.write("Upload an image of a flower to classify its type.")

# Upload image
uploaded_file = st.sidebar.file_uploader(
    "Choose an image...", type=["jpg", "jpeg", "png"]
)

if uploaded_file is not None:
    # The upload is untrusted: a corrupt or mislabelled file makes Pillow
    # raise OSError (UnidentifiedImageError is a subclass). Report it to the
    # user instead of crashing the app with a traceback.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        st.stop()

    # Display the uploaded image
    st.image(image, caption='Uploaded Image.', use_column_width=True)

    # Predict the class of the image
    probabilities, predicted_class = predict(image)

    # Display the probabilities in a bar chart
    fig, ax = plt.subplots()
    ax.bar(list(class_names.values()), probabilities)
    ax.set_ylabel('Probability')
    ax.set_xlabel('Class')
    ax.set_title('Class Probabilities')
    st.pyplot(fig)
    # pyplot keeps a reference to every figure it creates; close this one so
    # figures do not accumulate across reruns of a long-lived server.
    plt.close(fig)

    # Display the predicted class
    st.write(f"Predicted class: **{predicted_class}**")