yiw commited on
Commit
7312dc7
·
1 Parent(s): 8267b4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -5
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
 
 
4
 
5
  classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
6
  def main():
@@ -8,11 +10,34 @@ def main():
8
 
9
  with st.form("text_field"):
10
  uploaded_file = st.file_uploader("Upload Files",type=['png','jpeg','jpg'])
11
- # clicked==True only when the button is clicked
12
- clicked = st.form_submit_button("Submit")
13
- if clicked:
14
- img=Image.open(uploaded_file)
15
- results = classifier([img])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  st.image(img)
17
 
18
  if __name__ == "__main__":
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
+ import numpy as np
5
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
6
 
7
  classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
8
  def main():
 
10
 
11
  with st.form("text_field"):
12
  uploaded_file = st.file_uploader("Upload Files",type=['png','jpeg','jpg'])
13
+
14
+ if uploaded_file!=None:
15
+
16
+ img=Image.open(uploaded_file)
17
+
18
+ extractor = AutoFeatureExtractor.from_pretrained("yangy50/garbage-classification")
19
+ model = AutoModelForImageClassification.from_pretrained("yangy50/garbage-classification")
20
+
21
+ inputs = extractor(img,return_tensors="pt")
22
+ outputs = model(**inputs)
23
+ label_num=outputs.logits.softmax(1).argmax(1)
24
+ label_num=label_num.item()
25
+
26
+ st.write("The prediction class is:")
27
+
28
+ if label_num==0:
29
+ st.write("cardboard")
30
+ elif label_num==1:
31
+ st.write("glass")
32
+ elif label_num==2:
33
+ st.write("metal")
34
+ elif label_num==3:
35
+ st.write("paper")
36
+ elif label_num==4:
37
+ st.write("plastic")
38
+ else:
39
+ st.write("trash")
40
+
41
  st.image(img)
42
 
43
  if __name__ == "__main__":