yiw commited on
Commit
b84ede9
·
1 Parent(s): 0730cf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -35
app.py CHANGED
@@ -1,48 +1,25 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
- import numpy as np
5
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 
6
 
7
- classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
8
  def main():
9
- st.title("text-classification")
10
 
11
  with st.form("text_field"):
12
- uploaded_file = st.file_uploader("Upload Files",type=['png','jpeg','jpg'])
 
 
 
13
  clicked = st.form_submit_button("Submit")
14
  if clicked:
15
- uploaded_file = st.file_uploader("Upload Files",type=['png','jpeg','jpg'])
16
-
17
- if uploaded_file!=None:
 
18
 
19
- img=Image.open(uploaded_file)
20
-
21
- extractor = AutoFeatureExtractor.from_pretrained("yangy50/garbage-classification")
22
- model = AutoModelForImageClassification.from_pretrained("yangy50/garbage-classification")
23
-
24
- inputs = extractor(img,return_tensors="pt")
25
- outputs = model(**inputs)
26
- label_num=outputs.logits.softmax(1).argmax(1)
27
- label_num=label_num.item()
28
-
29
- st.write("The prediction class is:")
30
-
31
- if label_num==0:
32
- st.write("cardboard")
33
- elif label_num==1:
34
- st.write("glass")
35
- elif label_num==2:
36
- st.write("metal")
37
- elif label_num==3:
38
- st.write("paper")
39
- elif label_num==4:
40
- st.write("plastic")
41
- else:
42
- st.write("trash")
43
-
44
- st.image(img)
45
-
46
 
47
  if __name__ == "__main__":
48
  main()
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
+ classifier = pipeline("image-classification", model="google/vit-base-patch16-224")
5
+ def get_img_from_url(url):
6
+ return Image.open(requests.get(url, stream=True).raw)
7
 
 
8
  def main():
9
+ st.title("Yelp review")
10
 
11
  with st.form("text_field"):
12
+ text = st.text_area('enter some text:')
13
+ url = st.text_input("URL to some image", "https://images.livemint.com/img/2022/08/01/600x338/Cat-andriyko-podilnyk-RCfi7vgJjUY-unsplash_1659328989095_1659328998370_1659328998370.jpg")
14
+ img = get_img_from_url(url)
15
+ # clicked==True only when the button is clicked
16
  clicked = st.form_submit_button("Submit")
17
  if clicked:
18
+ results = classifier([text])
19
+ st.json(results)
20
+
21
+
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  if __name__ == "__main__":
25
  main()