Hrishikesh332 commited on
Commit
0cbbcf7
·
1 Parent(s): 326eb74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import streamlit as st
 
 
2
  from PIL import Image
3
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
4
 
@@ -7,6 +9,9 @@ def load_image(img):
7
  return im
8
  size=20
9
 
 
 
 
10
  st.markdown("<h1 style='text-align: center;'>Memeter 💬</h1>", unsafe_allow_html=True)
11
  st.markdown("---")
12
  with st.sidebar:
@@ -15,12 +20,21 @@ with st.sidebar:
15
  Memeter is an application used for the classification of whether the images provided is meme or not meme
16
  ''', unsafe_allow_html=False)
17
 
18
- img = st.file_uploader("Choose a Image:")
 
 
 
 
 
 
 
 
19
  if img is not None:
20
- st.write(img)
21
- extractor = AutoFeatureExtractor.from_pretrained("Hrishikesh332/autotrain-meme-classification-42897109437")
22
- model = AutoModelForImageClassification.from_pretrained("Hrishikesh332/autotrain-meme-classification-42897109437")
23
- pred=model(img)
24
- st.code(pred)
 
25
 
26
 
 
1
  import streamlit as st
2
+ import requests
3
+ from io import BytesIO
4
  from PIL import Image
5
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
6
 
 
9
  return im
10
  size=20
11
 
12
+ extractor = AutoFeatureExtractor.from_pretrained("Hrishikesh332/autotrain-meme-classification-42897109437")
13
+ model = AutoModelForImageClassification.from_pretrained("Hrishikesh332/autotrain-meme-classification-42897109437")
14
+
15
  st.markdown("<h1 style='text-align: center;'>Memeter 💬</h1>", unsafe_allow_html=True)
16
  st.markdown("---")
17
  with st.sidebar:
 
20
  Memeter is an application used for the classification of whether the images provided is meme or not meme
21
  ''', unsafe_allow_html=False)
22
 
23
+ img = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
24
+
25
+ def predict(image):
26
+ inputs = extractor(images=image, return_tensors="pt")
27
+
28
+ outputs = model(**inputs)
29
+ scores = outputs.logits.detach().numpy()
30
+ return scores
31
+
32
  if img is not None:
33
+ try:
34
+ image = Image.open(BytesIO(img.read()))
35
+ s = predict(image)
36
+ st.write("Value:", s)
37
+ except:
38
+ st.write("Pleas do upload the image in the correct format!")
39
 
40