# ActionNet / app.py
# Commit 845eb37 by DataRaptor: "Upload 64 files" (file size 4.23 kB).
import streamlit as st
import numpy as np
from PIL import Image
import requests
import ModelClass
from glob import glob
import torch
import torch.nn as nn
import numpy as np
@st.cache_resource
def load_model():
    """Return the classifier built by ``ModelClass.get_model()``.

    The ``st.cache_resource`` decorator makes Streamlit keep the returned
    object and hand it back on later calls, so the model is not rebuilt
    every time ``infer`` asks for it.
    """
    model = ModelClass.get_model()
    return model
@st.cache_data
def get_images():
    """Map each file in ``./inputs`` to its path, for the test-image picker.

    Returns:
        dict mapping bare file name -> path as returned by ``glob``. Entries
        are inserted in sorted order so the select box lists images in a
        stable order across runs and platforms. Sub-directories are skipped
        because they cannot be opened as images.
    """
    import os  # local import: only this helper needs path utilities

    paths = sorted(p for p in glob('./inputs/*') if os.path.isfile(p))
    # os.path.basename honours the platform separator (split('/') does not
    # on Windows, where glob returns backslash-separated paths).
    return {os.path.basename(p): p for p in paths}
def infer(img):
    """Classify a PIL image and return per-class probabilities.

    Args:
        img: a PIL ``Image``. It is converted to RGB first, so grayscale or
            RGBA images are accepted.

    Returns:
        A torch tensor of softmax probabilities with size-1 dimensions
        squeezed out. Assumes the model emits ``(1, num_classes)`` logits,
        giving a ``(num_classes,)`` result; TODO confirm against
        ``ModelClass.get_model()``.
    """
    image = img.convert('RGB')
    image = ModelClass.get_transform()(image)
    batch = image.unsqueeze(dim=0)  # add a leading batch dimension of size 1
    model = load_model()
    model.eval()  # inference mode for layers such as dropout / batch-norm
    with torch.no_grad():
        logits = model(batch)
    # Explicit class axis: nn.Softmax() without `dim` is deprecated and only
    # guesses the axis (it picks dim=1 for 2-D input, i.e. the same as -1 here).
    return torch.softmax(logits, dim=-1).squeeze()
# Page metadata shown in the browser tab and in the app's top-right menu.
# The title previously read "Whale Identification" (left over from another
# project), and the menu pointed at the placeholder help/bug URLs from
# Streamlit's documentation example; add real links under 'Get help' /
# 'Report a bug' if the project has them.
st.set_page_config(
    page_title="ActionNet",
    page_icon="🧊",
    layout="centered",
    initial_sidebar_state="expanded",
    menu_items={
        'About': """
# ActionNet
Human Action Recognition using CNN: a ResNet model that classifies human
activities in images into 15 activity classes.
""",
    },
)
# Cosmetic CSS overrides for the sidebar and the main content container,
# injected as raw HTML.
# NOTE(review): the `.css-xxxxxx` selectors look like auto-generated class
# names from one specific Streamlit build; they will likely stop matching
# after a Streamlit upgrade — confirm against the deployed version.
_SIDEBAR_CSS = """
<style>
.css-vk3wp9 {
background-color: rgb(255 255 255);
}
.css-18l0hbk {
padding: 0.34rem 1.2rem !important;
margin: 0.125rem 2rem;
}
.css-nziaof {
padding: 0.34rem 1.2rem !important;
margin: 0.125rem 2rem;
background-color: rgb(181 197 227 / 18%) !important;
}
.css-1y4p8pa {
padding: 3rem 1rem 10rem;
max-width: 58rem;
}
</style>
"""
st.markdown(_SIDEBAR_CSS, unsafe_allow_html=True)
# CSS that hides the #MainMenu element, the page footer and the page header.
# It is defined but currently NOT applied: the st.markdown call that would
# inject it is commented out below. Uncomment that line to hide these elements.
hide_st_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
</style>
"""
#st.markdown(hide_st_style, unsafe_allow_html=True)
def predict(image):
    """Return a random placeholder prediction; ``image`` is ignored.

    Draws one uniform random weight per label and normalises the weights so
    they sum to 1, returning ``{label: probability}``. ``app()`` in this
    file uses ``infer`` rather than this stub.
    """
    labels = ('cat', 'dog')
    weights = np.random.rand(len(labels))
    weights = weights / weights.sum()  # normalise into a probability vector
    return {label: weight for label, weight in zip(labels, weights)}
def _top_k_indices(prob, k):
    """Return the indices of the ``k`` largest entries of 1-D ``prob``, largest first.

    ``k`` is capped at ``len(prob)`` so a model with fewer than ``k`` classes
    still works (``np.argpartition`` with ``kth=-k`` would raise there).
    """
    k = min(k, len(prob))
    # Sorting the negated scores gives descending order; 'stable' keeps the
    # lower class index first when two probabilities tie.
    return np.argsort(-prob, kind='stable')[:k].tolist()


def app():
    """Render the ActionNet page.

    The user either uploads an image or picks one from ``./inputs``. The
    selected image is shown, and on button press the model's top-4 classes
    are listed with their probabilities, highest first.
    """
    st.title('ActionNet')
    st.markdown('Human Action Recognition using CNN: A Computer Vision project that trains a ResNet model to classify human activities. The dataset contains 15 activity classes, and the model predicts the activity from input images.')

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    test_images = get_images()
    # selectbox returns None when ./inputs is empty.
    test_image = st.selectbox('Or choose a test image', list(test_images.keys()))

    st.markdown('#### Selected Image')
    left_column, right_column = st.columns([1.5, 2.5], gap="medium")

    with left_column:
        # An uploaded file takes precedence over the test-image picker.
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
        elif test_image is not None:
            image = Image.open(test_images[test_image])
        else:
            image = None
            st.info('Upload an image or add test images to ./inputs.')

        if image is not None:
            st.image(image, use_column_width=True)
            if st.button('✨ Get prediction from AI', type='primary'):
                prob = infer(image).numpy()
                right_column.markdown('#### Results')
                for i in _top_k_indices(prob, 4):
                    class_name = ModelClass.get_class(i).replace('_', ' ').capitalize()
                    class_probability = float(prob[i])
                    right_column.write(f'{class_name}: {class_probability:.2%}')
                    # st.progress rejects floats outside [0, 1]; guard against
                    # floating-point overshoot.
                    right_column.progress(min(max(class_probability, 0.0), 1.0))

    st.markdown("---")
    st.markdown("Built by [Shamim Ahamed](https://www.shamimahamed.com/). Data provided by [aiplanet](https://aiplanet.com/challenges/data-sprint-76-human-activity-recognition/233/overview/about)")
app()  # render the page; runs at import time of this script