# text / app.py
# Last change: "Update app.py" by yiw (commit 0730cf9), 1.72 kB
import streamlit as st
from transformers import pipeline
from PIL import Image
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Text-classification pipeline, constructed whenever this module is executed.
# NOTE(review): `classifier` is not referenced anywhere in the visible code,
# yet building it loads a model on every execution — consider removing it.
classifier = pipeline(
    task="text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
# Model id passed to ``from_pretrained`` for both the extractor and the model.
MODEL_ID = "yangy50/garbage-classification"

# Display names indexed by the model's predicted label id, in the same order
# the original if/elif chain used (0=cardboard ... 5=trash).
CLASS_NAMES = ("cardboard", "glass", "metal", "paper", "plastic", "trash")


def class_name(label_num: int) -> str:
    """Return the display name for the predicted label id *label_num*.

    Ids 0-5 index into ``CLASS_NAMES``; any other id falls back to
    ``"trash"``, matching the catch-all ``else`` of the original chain.
    """
    if 0 <= label_num < len(CLASS_NAMES):
        return CLASS_NAMES[label_num]
    return "trash"


def predict_class(img):
    """Classify the PIL image *img* and return its class name.

    The feature extractor and model for ``MODEL_ID`` are loaded on each call.
    """
    # NOTE(review): this reloads the model on every submission. Streamlit
    # re-executes the script on each rerun, so a plain functools cache would not
    # persist; consider st.cache_resource if the installed Streamlit has it.
    extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    # Uploaded PNGs can be RGBA, palette or grayscale; normalise to 3-channel
    # RGB before extraction (assumes the extractor expects RGB — confirm).
    inputs = extractor(img.convert("RGB"), return_tensors="pt")
    outputs = model(**inputs)
    # softmax is monotonic, so argmax over the raw logits picks the same class.
    label_num = outputs.logits.argmax(1).item()
    return class_name(label_num)


def main():
    """Render the garbage-classification page.

    A form collects one png/jpeg/jpg upload. When "Submit" is pressed with a
    file attached, the predicted class is written, followed by the image.
    """
    st.title("text-classification")
    with st.form("text_field"):
        uploaded_file = st.file_uploader("Upload Files", type=['png', 'jpeg', 'jpg'])
        clicked = st.form_submit_button("Submit")
    # Use the file submitted through the form. The previous version created a
    # second, identical uploader at this point and overwrote `uploaded_file`
    # with that widget's value, losing the file the user had submitted.
    if not clicked or uploaded_file is None:
        return
    img = Image.open(uploaded_file)
    prediction = predict_class(img)
    st.write("The prediction class is:")
    st.write(prediction)
    st.image(img)


if __name__ == "__main__":
    main()