File size: 1,720 Bytes
997da96
864cbcc
4557b21
7312dc7
 
997da96
0ce8030
997da96
aff340b
997da96
864cbcc
4557b21
0730cf9
 
 
 
fb103a1
0730cf9
fb103a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0730cf9
7312dc7
997da96
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import streamlit as st
from transformers import pipeline
from PIL import Image
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Text-classification pipeline, constructed once when the module is loaded.
# NOTE(review): nothing in the visible code reads `classifier` — confirm it is
# still needed, since building it happens on every script run.
classifier = pipeline(
    task="text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
def main():
    """Render the page: upload an image, classify it, and show the result.

    A form offers a file uploader (png/jpeg/jpg) and a Submit button. When
    the form is submitted with a file, the image is run through the
    ``yangy50/garbage-classification`` model; the predicted class name is
    written to the page, followed by the uploaded image. Submitting without
    a file shows a warning instead of silently doing nothing.
    """
    # Class names indexed by the model's output index. Any index outside
    # this table is reported as "trash".
    labels = ("cardboard", "glass", "metal", "paper", "plastic")

    st.title("text-classification")

    with st.form("text_field"):
        # Exactly one uploader: Streamlit rejects a second widget created
        # with identical parameters and no `key`, and a fresh uploader would
        # not hold the file the user selected anyway.
        uploaded_file = st.file_uploader("Upload Files", type=['png', 'jpeg', 'jpg'])
        clicked = st.form_submit_button("Submit")

        if not clicked:
            return

        if uploaded_file is None:
            st.warning("Please upload an image before submitting.")
            return

        img = Image.open(uploaded_file)

        # NOTE(review): the extractor and model are loaded on every
        # submission; consider Streamlit's resource caching if this is slow.
        extractor = AutoFeatureExtractor.from_pretrained("yangy50/garbage-classification")
        model = AutoModelForImageClassification.from_pretrained("yangy50/garbage-classification")

        # PNG uploads may be RGBA or palette-based; feed the extractor a
        # 3-channel RGB copy (presumably what the model expects — confirm)
        # while keeping the original image for display.
        inputs = extractor(img.convert("RGB"), return_tensors="pt")
        outputs = model(**inputs)

        # argmax over the logits picks the same class as argmax over their
        # softmax, so the softmax step is unnecessary.
        label_num = outputs.logits.argmax(1).item()

        st.write("The prediction class is:")
        if 0 <= label_num < len(labels):
            st.write(labels[label_num])
        else:
            st.write("trash")

        st.image(img)
    
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()