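# NOTE(review): this file originally began with stray page text ("Spaces:",
# "Sleeping", "Sleeping"), apparently copied from a Hugging Face Spaces page;
# kept here as a comment so it is not executed as code.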
import os
from glob import glob

import numpy as np
import requests
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image

import ModelClass
def load_model():
    """Build the classifier through ``ModelClass.get_model()`` and hand it back.

    NOTE(review): nothing here caches the result, so every call (one per
    prediction, via ``infer``) repeats whatever ``get_model`` does -- confirm
    whether that includes loading weights from disk.
    """
    model = ModelClass.get_model()
    return model
def get_images():
    """Collect the bundled test images from ``./inputs``.

    The directory is resolved relative to the current working directory.

    Returns:
        dict mapping each entry's base file name to its path exactly as
        returned by ``glob``. Entries are inserted in sorted order so the
        selectbox that lists them has a stable ordering. Empty when the
        directory is missing or empty.
    """
    # os.path.basename honours the platform separator; the previous
    # split('/') left "inputs\\name" as the key on Windows.
    # glob() returns entries in arbitrary filesystem order, hence sorted().
    return {os.path.basename(path): path for path in sorted(glob('./inputs/*'))}
def infer(img):
    """Classify a single image and return its class probabilities.

    Args:
        img: A PIL image. It is converted to RGB and passed through
            ``ModelClass.get_transform()`` before inference.

    Returns:
        A ``torch.Tensor`` of softmax probabilities with singleton dimensions
        squeezed away. This is 1-D when the model emits logits of shape
        ``(batch, num_classes)`` -- assumed from how callers index it; TODO
        confirm against ModelClass.
    """
    image = ModelClass.get_transform()(img.convert('RGB'))
    batch = image.unsqueeze(dim=0)  # prepend a batch dimension of size 1
    model = load_model()
    model.eval()  # evaluation mode for layers such as dropout / batch-norm
    with torch.no_grad():
        logits = model(batch)
    # nn.Softmax() with no dim is deprecated and warns on every call; for 2-D
    # logits its implicit choice was dim=1, so stating it keeps results equal.
    return torch.softmax(logits, dim=1).squeeze()
# Page metadata and the top-right app menu.
st.set_page_config(
    # Previously "Whale Identification", left over from a template; the app
    # itself is titled ActionNet.
    page_title="ActionNet",
    page_icon="🧊",
    layout="centered",
    initial_sidebar_state="expanded",
    menu_items={
        # None hides these entries instead of linking to the template's
        # placeholder domain (extremelycoolapp.com).
        'Get Help': None,
        'Report a bug': None,
        'About': """
# ActionNet
Human Action Recognition using CNN: a ResNet model that classifies human
activities in images into 15 activity classes.
""",
    },
)
# Page-level CSS overrides injected as raw HTML (the original note describes
# them as a sidebar fix). NOTE(review): the selectors are Streamlit's
# auto-generated "css-xxxx" class names, which are not stable across
# Streamlit releases -- confirm they still match the pinned version.
_CUSTOM_CSS = """
<style>
.css-vk3wp9 {
background-color: rgb(255 255 255);
}
.css-18l0hbk {
padding: 0.34rem 1.2rem !important;
margin: 0.125rem 2rem;
}
.css-nziaof {
padding: 0.34rem 1.2rem !important;
margin: 0.125rem 2rem;
background-color: rgb(181 197 227 / 18%) !important;
}
.css-1y4p8pa {
padding: 3rem 1rem 10rem;
max-width: 58rem;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# CSS that hides the elements matched by #MainMenu, footer and header.
hide_st_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
</style>
"""
# Off by default: the original left the injecting call commented out.
# Set to True to apply hide_st_style instead of editing commented-out code.
HIDE_STREAMLIT_CHROME = False
if HIDE_STREAMLIT_CHROME:
    st.markdown(hide_st_style, unsafe_allow_html=True)
def predict(image):
    """Produce placeholder probabilities for two dummy labels.

    ``image`` is accepted but ignored. NOTE(review): looks like leftover
    template code -- nothing in the visible file calls it.

    Returns:
        dict mapping 'cat' and 'dog' to random weights normalised to sum to 1.
    """
    labels = ('cat', 'dog')
    weights = np.random.rand(len(labels))
    weights = weights / weights.sum()
    return {label: weight for label, weight in zip(labels, weights)}
def _render_top_predictions(column, prob, top_k=4):
    """Write the ``top_k`` most probable classes into ``column``, best first.

    Args:
        column: Streamlit container (a column from ``st.columns``) to write to.
        prob: 1-D NumPy array of class probabilities, indexed by the class ids
            that ``ModelClass.get_class`` accepts (assumed; confirm).
        top_k: Maximum number of classes to list.
    """
    column.markdown('#### Results')
    # argsort is ascending, so reverse it for best-first order. Slicing also
    # copes with fewer than top_k classes, where np.argpartition would raise.
    for i in np.argsort(prob)[::-1][:top_k]:
        class_name = ModelClass.get_class(i).replace('_', ' ').capitalize()
        class_probability = float(prob[i])
        column.write(f'{class_name}: {class_probability:.2%}')
        # st.progress rejects floats outside [0.0, 1.0]; guard against rounding.
        column.progress(min(max(class_probability, 0.0), 1.0))


def app():
    """Render the ActionNet page.

    The user uploads an image or picks one from ``./inputs``; the image is
    shown in the left column. On button press it is classified and the top
    predictions appear in the right column.
    """
    st.title('ActionNet')
    st.markdown('Human Action Recognition using CNN: A Computer Vision project that trains a ResNet model to classify human activities. The dataset contains 15 activity classes, and the model predicts the activity from input images.')

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    test_images = get_images()
    test_image = st.selectbox('Or choose a test image', list(test_images.keys()))

    st.markdown('#### Selected Image')
    left_column, right_column = st.columns([1.5, 2.5], gap="medium")
    with left_column:
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
        elif test_image is not None:
            image = Image.open(test_images[test_image])
        else:
            # selectbox yields None when ./inputs has no files; without this
            # branch the lookup above raised KeyError.
            image = None
            st.info('Upload an image to get started.')

        if image is not None:
            st.image(image, use_column_width=True)
            if st.button('✨ Get prediction from AI', type='primary'):
                prob = infer(image).numpy()
                _render_top_predictions(right_column, prob)

    st.markdown("---")
    st.markdown("Built by [Shamim Ahamed](https://www.shamimahamed.com/). Data provided by [aiplanet](https://aiplanet.com/challenges/data-sprint-76-human-activity-recognition/233/overview/about)")
# `streamlit run` executes this script as __main__, so the page still renders.
# A plain import no longer calls app(); the top-level st.* calls above still run.
if __name__ == "__main__":
    app()