|
import streamlit as st |
|
from transformers import pipeline |
|
from PIL import Image |
|
import numpy as np |
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
|
|
@st.cache_resource
def _load_text_classifier():
    """Build the DistilBERT SST-2 sentiment pipeline once per server process.

    Streamlit re-executes this script on every user interaction; wrapping the
    construction in ``st.cache_resource`` keeps the pipeline (and its weights)
    from being rebuilt on each rerun.
    """
    return pipeline(
        "text-classification",
        model="distilbert-base-uncased-finetuned-sst-2-english",
    )


# NOTE(review): nothing in this script reads `classifier`; it is kept only in
# case other code imports it. Deleting it would save startup time and memory.
classifier = _load_text_classifier()
|
# Hugging Face Hub id of the image classifier used by this app.
_MODEL_ID = "yangy50/garbage-classification"

# Display names indexed by predicted class id; ids outside 0-5 fall back to
# "trash" (see _class_name).
_GARBAGE_CLASSES = ("cardboard", "glass", "metal", "paper", "plastic", "trash")


@st.cache_resource
def _load_model():
    """Load the feature extractor and classifier once and reuse them.

    Streamlit re-runs the whole script on every interaction; caching keeps the
    weights from being re-read (or re-downloaded) on every submit.

    Returns:
        ``(extractor, model)`` for ``_MODEL_ID``.
    """
    extractor = AutoFeatureExtractor.from_pretrained(_MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
    return extractor, model


def _class_name(label_id):
    """Map a predicted class id to its display name ("trash" for unknown ids)."""
    if 0 <= label_id < len(_GARBAGE_CLASSES):
        return _GARBAGE_CLASSES[label_id]
    return "trash"


def _predict(img):
    """Return the predicted garbage class name for a PIL image."""
    import torch  # already required: the extractor is asked for PyTorch tensors

    extractor, model = _load_model()
    # Uploaded PNGs may be RGBA, palette or grayscale; presumably the model was
    # trained on 3-channel RGB, so normalize the mode before preprocessing.
    inputs = extractor(img.convert("RGB"), return_tensors="pt")
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits
    # softmax is monotonic, so argmax over raw logits selects the same class.
    return _class_name(int(logits.argmax(dim=1).item()))


def main():
    """Render the app: upload an image, submit, and show the predicted class."""
    st.title("garbage-classification")

    # A single uploader lives inside the form; its value is delivered together
    # with the submit click. (Creating a second identical uploader after the
    # click is rejected by Streamlit as a duplicate widget.)
    with st.form("text_field"):
        uploaded_file = st.file_uploader("Upload Files", type=["png", "jpeg", "jpg"])
        clicked = st.form_submit_button("Submit")

    if not clicked:
        return
    if uploaded_file is None:
        st.warning("Please upload an image before submitting.")
        return

    try:
        img = Image.open(uploaded_file)
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error("Could not read the uploaded file as an image.")
        return

    st.write("The prediction class is:")
    st.write(_predict(img))
    st.image(img)


if __name__ == "__main__":
    main()