text / app.py
yiw's picture
Update app.py
7312dc7
raw
history blame
1.32 kB
import streamlit as st
from transformers import pipeline
from PIL import Image
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# NOTE(review): this text-classification pipeline (distilbert SST-2 checkpoint)
# is constructed at import time, but nothing in the visible code (including
# main()) references `classifier`. Unless it is used elsewhere, it presumably
# only adds start-up cost (model load); consider removing it.
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
# Class names indexed by the model's predicted label id, in the order the app
# has always reported them; any id past the end is reported as "trash".
_CLASS_NAMES = ("cardboard", "glass", "metal", "paper", "plastic")


def _classify(img):
    """Return the garbage class name predicted for a PIL image.

    Loads the "yangy50/garbage-classification" feature extractor and model,
    runs a single forward pass, and maps the arg-max logit index to a name
    from ``_CLASS_NAMES`` (falling back to "trash").

    Args:
        img: A PIL image; callers should pass an RGB image.

    Returns:
        The predicted class name as a string.
    """
    # Local import: torch is already required for return_tensors="pt" below.
    import torch

    # NOTE(review): both objects are reloaded on every call; consider caching
    # them (e.g. st.cache_resource) if the deployed Streamlit version has it.
    extractor = AutoFeatureExtractor.from_pretrained("yangy50/garbage-classification")
    model = AutoModelForImageClassification.from_pretrained("yangy50/garbage-classification")
    inputs = extractor(img, return_tensors="pt")
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    # softmax is monotonic, so arg-max over the raw logits selects the same class.
    label_num = logits.argmax(1).item()
    return _CLASS_NAMES[label_num] if label_num < len(_CLASS_NAMES) else "trash"


def main():
    """Render the upload form and, once submitted, show the predicted class."""
    # NOTE(review): the title says "text-classification" but the app classifies images.
    st.title("text-classification")
    with st.form("text_field"):
        uploaded_file = st.file_uploader("Upload Files", type=['png', 'jpeg', 'jpg'])
        # A form sends its widget values only when submitted; without this
        # button the uploaded file never reaches the script.
        st.form_submit_button("Submit")
    if uploaded_file is not None:
        img = Image.open(uploaded_file)
        # The extractor expects 3-channel input; PNG uploads may be RGBA,
        # palette or grayscale. Convert only for inference and display the original.
        class_name = _classify(img.convert("RGB"))
        st.write("The prediction class is:")
        st.write(class_name)
        st.image(img)


if __name__ == "__main__":
    main()