# yolo.dog / app.py
# Author: ch-tseng — commit dbb47b9 ("update")
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
@File Name: app.py
@Author: Luyao.zhang
@Date: 2023/5/15
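@Description: Streamlit front end for YOLO.dog. It downloads the configured
              detection models and runs detection on an uploaded image, an
              uploaded video, or a webcam stream.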
-------------------------------------------------
"""
# --- Deployment switches -------------------------------------------------
# True: configuration is read from st.secrets; False: from environment
# variables, populated from a local .env file below.
HuggingFace = True
# Index i of a model ("m{i}_*" settings) whose cached weights file should be
# deleted at startup; None disables that maintenance step.
update_model_id = None
# Local directory that caches downloaded model weights.
base_download_path = "downloaded"

if not HuggingFace:
    # Local runs keep their settings in a .env file; load it into os.environ.
    from dotenv import load_dotenv
    load_dotenv()
from pathlib import Path
from PIL import Image
import streamlit as st
import config
from utils import load_model, infer_uploaded_image, infer_uploaded_video, infer_uploaded_webcam
import os
# Query-string parameters of the current page; each key maps to a list of
# values (e.g. ?model=X gives {'model': ['X']}).
query_params = st.experimental_get_query_params()
#------------
# Optional maintenance step: delete the cached weights of model
# `update_model_id` so that the download step fetches a fresh copy.
if update_model_id is not None:
    if HuggingFace is False:
        model_name = os.getenv("m{}_name".format(update_model_id))
        model_extname = os.getenv("m{}_type".format(update_model_id))
    else:
        model_name = st.secrets["m{}_name".format(update_model_id)]
        model_extname = st.secrets["m{}_type".format(update_model_id)]
    path_model = os.path.join(base_download_path, model_name + model_extname)
    # EAFP: attempt the removal directly instead of racing an exists() check.
    try:
        os.remove(path_model)
    except FileNotFoundError:
        pass  # nothing cached yet — nothing to purge
    except OSError as err:
        # Best effort: report and continue with the stale file.
        print("Cannot remove", path_model, err)
#-------------
# Model requested through the URL (?model=<description>); the first value
# wins, and 'Crowded-Human' is used when the parameter is absent.
qmodel = query_params.get('model', ['Crowded-Human'])[0]
# setting page layout
# Page-level layout. Per the Streamlit docs this must be issued before any
# element is rendered on the page.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="YOLO.dog",
    page_icon="🤖",
)
#------------------------
# Make sure the local model cache directory exists. exist_ok=True makes this
# a no-op when the directory is already there, so no exists() pre-check is
# needed. Any other failure (permissions, a file with the same name, ...) is
# reported, instead of being mislabelled as "folder exists".
try:
    os.makedirs(base_download_path, exist_ok=True)
except OSError as err:
    print("Cannot create folder", base_download_path, err)
def _read_setting(key):
    """Return the configuration value stored under *key*.

    Reads environment variables (populated from .env) when ``HuggingFace``
    is False, otherwise Streamlit secrets.
    """
    if HuggingFace is False:
        return os.getenv(key)
    return st.secrets[key]


def _download_model(path_model, gdrive_id):
    """Download Google Drive file *gdrive_id* to *path_model* with wget.

    Returns True on success. On failure the output file is removed: wget -O
    can leave an empty or partial file behind, and a leftover file would make
    the os.path.exists() check skip this model on every later run.
    """
    import subprocess

    download_link = "https://drive.google.com/uc?export=download&confirm=t&id={}".format(gdrive_id)
    print('wget -O {} --content-disposition "{}"'.format(path_model, download_link))
    # An argument list (no shell) avoids quoting problems and shell
    # interpretation of the path or link contents.
    try:
        result = subprocess.run(
            ["wget", "-O", path_model, "--content-disposition", download_link],
            check=False,
        )
        ok = result.returncode == 0
    except OSError as err:  # e.g. wget is not installed
        print("Cannot run wget:", err)
        ok = False
    if not ok:
        print("Download failed for", path_model)
        try:
            os.remove(path_model)
        except OSError:
            pass  # nothing was written
    return ok


model_count = int(_read_setting("model_count"))
model_info = {}   # model description -> local path of its weights file
models_list = []  # model descriptions, in configuration order
for i in range(model_count):
    model_name = _read_setting("m{}_name".format(i))
    gdrive_id = _read_setting("m{}_griv".format(i))
    model_extname = _read_setting("m{}_type".format(i))
    model_desc = _read_setting("m{}_desc".format(i))
    print(i, model_name, gdrive_id, model_extname, model_desc)
    path_model = os.path.join(base_download_path, model_name + model_extname)
    print('path_model', path_model)
    model_info[model_desc] = path_model
    models_list.append(model_desc)
    # Download only weights that are not cached locally yet.
    if not os.path.exists(path_model):
        _download_model(path_model, gdrive_id)
# Without any configured model there is nothing to select or run;
# models_list[0] below would raise IndexError, so report the problem and stop.
if not models_list:
    st.error("No models are configured; please check model_count.")
    st.stop()
# Fall back to the first configured model when the requested one is unknown.
if qmodel not in models_list:
    qmodel = models_list[0]
# main page heading
#st.title( model_info[qmodel] )
# sidebar
st.sidebar.header("Model Config")

# model options
task_type = "Detection"
model_type = st.sidebar.selectbox(
    "Models list",
    models_list,
    index=models_list.index(qmodel),
)

# Confidence threshold picked in percent and converted to a 0-1 fraction.
# The default (5) must lie inside [min_value, max_value], so the range
# starts at 5.
confidence = float(st.sidebar.slider(
    "Select Model Confidence", min_value=5, max_value=100, value=5)) / 100

if not model_type:
    st.error("Please Select Model in Sidebar")
    # The inference section below cannot run without a model.
    st.stop()

st.header('{} Model Trial'.format(model_type), divider='rainbow')
st.subheader('Use your :blue[photo/video] :sunglasses:')
model_path = model_info[model_type]
print('model_path', model_path)
try:
    model = load_model(model_path)
except Exception as e:  # load_model's failure modes are not visible here
    print('Unable to load model', model_path, repr(e))
    st.error(f"Unable to load model. Please check the specified path: {model_path}")
    # Stop this script run: the inference calls below need `model` and
    # would otherwise fail with a NameError.
    st.stop()
# image/video options
st.sidebar.header("Image/Video Config")
source_selectbox = st.sidebar.selectbox(
    "Select Source",
    config.SOURCES_LIST
)

source_img = None
# Run inference on the chosen input source with the loaded model and the
# confidence threshold from the sidebar.
if source_selectbox == config.SOURCES_LIST[0]:  # Image
    infer_uploaded_image(confidence, model)
elif source_selectbox == config.SOURCES_LIST[1]:  # Video
    infer_uploaded_video(confidence, model)
elif source_selectbox == config.SOURCES_LIST[2]:  # Webcam
    infer_uploaded_webcam(confidence, model)
else:
    # Only reachable if config.SOURCES_LIST has more entries than the three
    # handled above; the message lists what is actually implemented.
    st.error("Currently only 'Image', 'Video' and 'Webcam' sources are implemented")