vikramjeetthakur commited on
Commit
81c25c0
·
verified ·
1 Parent(s): b3ffde7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -1,16 +1,18 @@
1
  import streamlit as st
2
  from PIL import Image
3
- from transformers import pipeline
 
4
 
5
- # Load the image classification pipeline
6
- model_name = "vikramjeetthakur/hotornot" # Replace with your model's name or path
7
- classifier = pipeline("image-classification", model=model_name)
 
8
 
9
  # Define the main function for the Streamlit app
10
  def main():
11
  st.title("Hot or Not Image Classifier")
12
 
13
- st.write("Upload an image to classify it as hot or not hot.")
14
 
15
  # Image upload
16
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
@@ -20,14 +22,18 @@ def main():
20
  image = Image.open(uploaded_file)
21
  st.image(image, caption="Uploaded Image", use_column_width=True)
22
 
23
- # Make predictions
24
- predictions = classifier(image)
25
 
26
- # Display results
27
- for prediction in predictions:
28
- label = prediction['label']
29
- score = prediction['score']
30
- st.write(f"**{label}** with confidence {score:.2f}")
 
 
 
 
31
 
32
  if __name__ == "__main__":
33
  main()
 
1
  import streamlit as st
2
  from PIL import Image
3
+ import torch
4
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
5
 
6
+ # Load a pre-trained model and feature extractor
7
+ model_name = "facebook/wide_resnet50_2" # Using a general model
8
+ model = AutoModelForImageClassification.from_pretrained(model_name)
9
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
10
 
11
  # Define the main function for the Streamlit app
12
  def main():
13
  st.title("Hot or Not Image Classifier")
14
 
15
+ st.write("Upload an image to classify it.")
16
 
17
  # Image upload
18
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
22
  image = Image.open(uploaded_file)
23
  st.image(image, caption="Uploaded Image", use_column_width=True)
24
 
25
+ # Preprocess the image
26
+ inputs = feature_extractor(images=image, return_tensors="pt")
27
 
28
+ # Make predictions
29
+ with torch.no_grad():
30
+ outputs = model(**inputs)
31
+ logits = outputs.logits # Get the logits
32
+ class_idx = logits.argmax(-1).item() # Get the index of the highest probability
33
+
34
+ # Display results based on class index
35
+ st.write(f"Predicted class index: {class_idx}")
36
+ st.write(f"Predicted class label: {model.config.id2label[class_idx]}")
37
 
38
  if __name__ == "__main__":
39
  main()